utils.py 11.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
from collections import Counter
from dataclasses import asdict, fields, is_dataclass
from pathlib import Path
from typing import Any

from omegaconf import OmegaConf
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_file_to_dict
from vllm.transformers_utils.repo_utils import file_or_path_exists

from vllm_omni.entrypoints.stage_utils import _to_dict
from vllm_omni.platforms import current_omni_platform

# Project root directory: three `.parent` hops from this file, i.e. two directories above the one containing it
PROJECT_ROOT = Path(__file__).parent.parent.parent

logger = init_logger(__name__)


def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None:
    """Merge connector settings into ``stage.engine_args["omni_kv_config"]``.

    Any existing ``omni_kv_config`` found through ``engine_args.get`` is used
    as the starting point; the keys ``connector_config``, ``omni_from_stage``
    and ``omni_to_stage`` are then set (overwriting previous values) and the
    result is stored back on ``stage.engine_args``.

    Args:
        stage: Object exposing an ``engine_args`` attribute (mapping-like or
            attribute-style container).
        omni_conn_cfg: Connector configuration stored under ``connector_config``.
        omni_from: Value stored under ``omni_from_stage``.
        omni_to: Value stored under ``omni_to_stage``.
    """
    # Start from whatever is already configured; reading is best-effort and
    # any failure simply means we begin with an empty dict.
    merged_cfg = {}
    try:
        engine_args = stage.engine_args
        current = engine_args.get("omni_kv_config", None) if hasattr(engine_args, "get") else None
        if current:
            # Mapping-like values are copied directly, anything else goes
            # through the project's _to_dict helper.
            merged_cfg = dict(current) if hasattr(current, "items") else _to_dict(current)
    except Exception:
        merged_cfg = {}

    # Connector details always override what was there before.
    for key, value in (
        ("connector_config", omni_conn_cfg),
        ("omni_from_stage", omni_from),
        ("omni_to_stage", omni_to),
    ):
        merged_cfg[key] = value

    # Store the result using item assignment when available, attribute
    # assignment otherwise; failures are logged rather than raised.
    try:
        if not hasattr(stage.engine_args, "__setitem__"):
            setattr(stage.engine_args, "omni_kv_config", merged_cfg)
        else:
            stage.engine_args["omni_kv_config"] = merged_cfg
    except Exception as e:
        logger.error(f"Failed to inject omni connector config into stage: {e}")


def _try_get_class_name_from_diffusers_config(model: str) -> str | None:
    """Look up the ``_class_name`` entry of a model's ``model_index.json``.

    Args:
        model: Model name or path

    Returns:
        The ``_class_name`` value when ``model_index.json`` could be loaded as
        a dict containing that key, None otherwise
    """
    index_data = get_hf_file_to_dict("model_index.json", model, revision=None)

    # Bail out early when the file is missing/empty, not a dict, or lacks the key.
    if not index_data or not isinstance(index_data, dict) or "_class_name" not in index_data:
        return None

    logger.debug(f"Found model_type '{index_data['_class_name']}' in model_index.json")
    return index_data["_class_name"]


def _convert_dataclasses_to_dict(obj: Any) -> Any:
    """Recursively convert non-serializable objects to OmegaConf-compatible types.

    Conversions performed (checked in this order):
    - Counter objects (``collections.Counter`` or any class named ``Counter``)
      become plain dicts; a ``Counter`` that cannot be turned into a dict
      becomes ``{}``
    - set / frozenset objects become lists
    - dataclass *instances* become dicts keyed by field name
    - dicts, lists and tuples (including namedtuples) are rebuilt with their
      values/items converted
    - other objects exposing ``keys`` and ``values`` become dicts when possible

    Anything else (primitives, dataclass classes, arbitrary objects) is
    returned unchanged.

    Args:
        obj: Arbitrary value, possibly nested

    Returns:
        The converted value
    """
    # IMPORTANT: Counter must be handled BEFORE dict, since collections.Counter
    # is a dict subclass. The class-name check also catches unrelated classes
    # that are merely called "Counter" and may not be mappings at all.
    if obj.__class__.__name__ == "Counter" or isinstance(obj, Counter):
        try:
            plain = dict(obj)
        except (TypeError, ValueError):
            # Not dict-convertible (e.g. a non-mapping Counter class)
            return {}
        return {k: _convert_dataclasses_to_dict(v) for k, v in plain.items()}
    # Sets become lists; items are converted too so nested values are handled
    if isinstance(obj, (set, frozenset)):
        return [_convert_dataclasses_to_dict(item) for item in obj]
    # Dataclass instances only: is_dataclass() is also True for dataclass
    # classes, which must be left alone.
    # Fields are walked manually instead of calling asdict(): asdict() rebuilds
    # dict subclasses via ``type(obj)(pairs)``, which for a Counter counts the
    # (key, value) pairs instead of restoring the mapping.
    if is_dataclass(obj) and not isinstance(obj, type):
        return {f.name: _convert_dataclasses_to_dict(getattr(obj, f.name)) for f in fields(obj)}
    # Plain dictionaries (must come AFTER the Counter check)
    if isinstance(obj, dict):
        return {k: _convert_dataclasses_to_dict(v) for k, v in obj.items()}
    # Lists and tuples keep their concrete type
    if isinstance(obj, (list, tuple)):
        converted = [_convert_dataclasses_to_dict(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field
            return type(obj)(*converted)
        return type(obj)(converted)
    # Try to convert any other dict-like object (has keys/values methods)
    if hasattr(obj, "keys") and hasattr(obj, "values") and not isinstance(obj, (str, bytes)):
        try:
            return {k: _convert_dataclasses_to_dict(v) for k, v in obj.items()}
        except (TypeError, ValueError, AttributeError):
            # If conversion fails, return as-is
            return obj
    # Primitive types and other objects are passed through untouched
    return obj


def resolve_model_config_path(model: str) -> str | None:
    """Resolve the stage config file path from the model name.

    The model type is determined in this order:
    1. ``get_config(model, trust_remote_code=True).model_type``
    2. if that fails and ``model_index.json`` exists, its ``_class_name`` entry
    3. otherwise, if ``config.json`` exists, its ``model_type`` entry

    The file ``{model_type}.yaml`` is then looked up first in the directory
    reported by ``current_omni_platform.get_default_stage_config_path()`` and,
    if absent there, in ``vllm_omni/model_executor/stage_configs/`` (both
    relative to ``PROJECT_ROOT``).

    Args:
        model: Model name or path (used to determine model_type)

    Returns:
        String path to the stage configuration file, or None when neither
        candidate file exists

    Raises:
        ValueError: If model_type cannot be determined
    """
    # Try to get config from standard transformers format first
    try:
        hf_config = get_config(model, trust_remote_code=True)
        model_type = hf_config.model_type
    except Exception:
        # Keep a trace of why the standard path failed before trying fallbacks
        logger.debug("Standard config loading failed for %s; trying fallbacks", model, exc_info=True)
        # If standard transformers format fails, try diffusers format
        if file_or_path_exists(model, "model_index.json", revision=None):
            model_type = _try_get_class_name_from_diffusers_config(model)
            if model_type is None:
                raise ValueError(
                    f"Could not determine model_type for diffusers model: {model}. "
                    f"Please ensure the model has 'model_type' in transformer/config.json or model_index.json"
                )
        elif file_or_path_exists(model, "config.json", revision=None):
            # Try to read config.json manually for custom models like Bagel that fail get_config
            # but have a valid config.json with model_type
            try:
                config_dict = get_hf_file_to_dict("config.json", model, revision=None)
            except Exception as e:
                raise ValueError(f"Failed to read config.json for model: {model}. Error: {e}") from e
            # Checked outside the try so this error is not re-wrapped as a read failure
            if not config_dict or "model_type" not in config_dict:
                raise ValueError(f"config.json found but missing 'model_type' for model: {model}")
            model_type = config_dict["model_type"]
        else:
            raise ValueError(
                f"Could not determine model_type for model: {model}. "
                f"Model is not in standard transformers format and does not have model_index.json. "
                f"Please ensure the model has proper configuration files with 'model_type' field"
            )

    # Platform-specific config takes precedence
    default_config_path = current_omni_platform.get_default_stage_config_path()
    model_type_str = f"{model_type}.yaml"
    complete_config_path = PROJECT_ROOT / default_config_path / model_type_str
    if os.path.exists(complete_config_path):
        return str(complete_config_path)

    # Fall back to default config
    stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml"
    stage_config_path = PROJECT_ROOT / stage_config_file
    if not os.path.exists(stage_config_path):
        # No stage config for this model type; callers treat None as "no stages"
        return None
    return str(stage_config_path)


def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list:
    """Load the stage configurations that belong to a model.

    The config file is located with ``resolve_model_config_path`` and parsed
    with ``load_stage_configs_from_yaml``.

    Args:
        model: Model name or path (used to determine model_type)
        base_engine_args: Engine arguments forwarded to
            ``load_stage_configs_from_yaml``; None is treated as an empty dict

    Returns:
        The stage configurations loaded from the resolved file, or an empty
        list when no config file was resolved
    """
    shared_args = {} if base_engine_args is None else base_engine_args
    resolved_path = resolve_model_config_path(model)
    # No config file for this model: nothing to load
    if resolved_path is None:
        return []
    return load_stage_configs_from_yaml(config_path=resolved_path, base_engine_args=shared_args)


def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None = None) -> list:
    """Load stage configurations from a YAML file.

    Each entry of the file's ``stage_args`` gets its ``engine_args`` replaced
    by ``base_engine_args`` merged with the entry's own ``engine_args``. For
    entries that define ``runtime`` and whose ``stage_type`` is not
    ``"diffusion"``, ``max_num_seqs`` is set from ``runtime.max_batch_size``
    (falling back to 1) and ``async_chunk`` from the file's top-level
    ``async_chunk`` value (False when absent).

    Args:
        config_path: Path to the YAML configuration file
        base_engine_args: Engine arguments shared by all stages; None is
            treated as an empty dict

    Returns:
        The file's ``stage_args`` with per-stage ``engine_args`` filled in
    """
    shared_source = {} if base_engine_args is None else base_engine_args
    loaded = OmegaConf.load(config_path)
    stages = loaded.stage_args
    async_chunk_flag = loaded.get("async_chunk", False)
    # OmegaConf cannot ingest dataclasses/Counters/sets, so normalise first
    shared_args = OmegaConf.create(_convert_dataclasses_to_dict(shared_source))

    for stage in stages:
        per_stage_args = shared_args.copy()
        # Stage-specific engine_args, when present, override the shared ones
        has_overrides = hasattr(stage, "engine_args") and stage.engine_args is not None
        if has_overrides:
            per_stage_args = OmegaConf.merge(per_stage_args, stage.engine_args)

        kind = getattr(stage, "stage_type", "llm")
        has_runtime = hasattr(stage, "runtime") and stage.runtime is not None
        if has_runtime and kind != "diffusion":
            runtime_section = stage.runtime
            # A missing/zero/None max_batch_size falls back to 1
            batch_limit = int(runtime_section.get("max_batch_size", 1) or 1)
            per_stage_args["max_num_seqs"] = batch_limit
            per_stage_args.async_chunk = async_chunk_flag

        stage.engine_args = per_stage_args

    return stages


def get_final_stage_id_for_e2e(
    output_modalities: list[str] | None, default_modalities: list[str], stage_list: list
) -> int:
    """Get the final stage id for e2e.

    Args:
        stage_list: List of stage configurations

    Returns:
        Final stage id for e2e
    """
    last_stage_id = len(stage_list) - 1
    if output_modalities is not None:
        prompt_modalities = []
        for modality in output_modalities:
            if modality not in default_modalities:
                logger.warning(f"Invalid output modality: {modality}, ignoring it")
                # TODO: if user specifies unsupported modalities, invalid it and raise an error
                continue
            prompt_modalities.append(modality)
        output_modalities = prompt_modalities
    else:
        output_modalities = default_modalities

    try:
        for _sid in range(last_stage_id, -1, -1):
            if (
                getattr(stage_list[_sid], "final_output", False)
                and stage_list[_sid].final_output_type in output_modalities
            ):
                final_stage_id_for_e2e = _sid
                break
        if final_stage_id_for_e2e < 0:
            final_stage_id_for_e2e = last_stage_id
    except Exception as e:
        logger.debug(
            "[Orchestrator] Failed to determine final stage for E2E; \
                falling back to last: %s",
            e,
            exc_info=True,
        )
        final_stage_id_for_e2e = last_stage_id

    return final_stage_id_for_e2e