# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Access the in-process vLLM model weights for compactor weight sharing."""

from __future__ import annotations

import torch.nn as nn

from vllm.logger import init_logger

logger = init_logger(__name__)


def extract_vllm_causal_lm(llm: object) -> nn.Module:
    """Fetch the model object behind an in-process v1 ``LLM``.

    Starting from ``llm``, follows the attribute chain
    ``llm_engine -> model_executor -> driver_worker -> worker`` and returns
    whatever the worker's ``get_model()`` produces (expected to be the root
    ``nn.Module`` holding transformer + lm_head).

    ``model_executor`` is only reachable when the engine runs in-process,
    i.e. ``LLMEngine`` was built with ``multiprocess_mode=False``
    (set ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).

    Args:
        llm: Object exposing an ``llm_engine`` attribute (e.g. ``vllm.LLM``).

    Returns:
        The result of calling ``get_model()`` on the resolved worker.

    Raises:
        RuntimeError: If any attribute along the chain is missing or
            ``None``, or if the worker has no callable ``get_model``.
    """
    # Each hop pairs the attribute to read with the error raised when that
    # attribute is missing or None; hops are resolved strictly in order.
    lookup_steps = (
        (
            "llm_engine",
            "Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).",
        ),
        (
            "model_executor",
            "model_executor is unavailable (multiprocess engine mode). "
            "Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
            "in-process weight sharing.",
        ),
        (
            "driver_worker",
            "Executor has no driver_worker (unexpected executor type for weight sharing).",
        ),
        (
            "worker",
            "Worker wrapper has no worker loaded.",
        ),
    )

    current = llm
    for attr_name, failure_message in lookup_steps:
        current = getattr(current, attr_name, None)
        if current is None:
            raise RuntimeError(failure_message)

    # ``current`` is now the worker; it must provide a callable get_model.
    model_getter = getattr(current, "get_model", None)
    if callable(model_getter):
        return model_getter()
    raise RuntimeError("Worker does not expose get_model().")