Unverified Commit d5ec6c05 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[UX] Add vLLM model inspection view (#29450)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 08d954f0
...@@ -348,6 +348,9 @@ class LLM: ...@@ -348,6 +348,9 @@ class LLM:
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
...@@ -1786,3 +1789,16 @@ class LLM: ...@@ -1786,3 +1789,16 @@ class LLM:
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
def __repr__(self) -> str:
    """Return a transformers-style hierarchical view of the model.

    The text is requested from the workers through ``collective_rpc`` the
    first time it is needed and then stored in ``self._cached_repr``, so
    later calls return the stored string without another RPC.
    """
    cached = self._cached_repr
    if cached is not None:
        return cached

    # One result per worker comes back in distributed settings; only the
    # first is kept (they should all be the same).
    worker_views = self.llm_engine.collective_rpc("get_model_inspection")
    if worker_views:
        cached = worker_views[0]
    else:
        # Nothing came back: fall back to a minimal one-line description.
        cached = f"LLM(model={self.model_config.model!r})"

    self._cached_repr = cached
    return cached
...@@ -250,6 +250,7 @@ if TYPE_CHECKING: ...@@ -250,6 +250,7 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_LOG_MODEL_INSPECTION: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False VLLM_DEBUG_MFU_METRICS: bool = False
...@@ -1595,6 +1596,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1595,6 +1596,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool( "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
), ),
# Log model inspection after loading.
# If enabled, logs a transformers-style hierarchical view of the model
# with quantization methods and attention backends.
"VLLM_LOG_MODEL_INSPECTION": lambda: bool(
int(os.getenv("VLLM_LOG_MODEL_INSPECTION", "0"))
),
# Debug logging for --enable-mfu-metrics # Debug logging for --enable-mfu-metrics
"VLLM_DEBUG_MFU_METRICS": lambda: bool( "VLLM_DEBUG_MFU_METRICS": lambda: bool(
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
......
...@@ -285,5 +285,5 @@ class ApplyRotaryEmb(CustomOp): ...@@ -285,5 +285,5 @@ class ApplyRotaryEmb(CustomOp):
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"is_neox_style={self.is_neox_style}" s = f"is_neox_style={self.is_neox_style}"
s += f"enable_fp32_compute={self.enable_fp32_compute}" s += f", enable_fp32_compute={self.enable_fp32_compute}"
return s return s
...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -50,8 +51,21 @@ class BaseModelLoader(ABC): ...@@ -50,8 +51,21 @@ class BaseModelLoader(ABC):
vllm_config=vllm_config, model_config=model_config vllm_config=vllm_config, model_config=model_config
) )
log_model_inspection(model)
logger.debug("Loading weights on %s ...", load_device) logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it # Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config) self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device) process_weights_after_loading(model, model_config, target_device)
return model.eval() return model.eval()
def log_model_inspection(model: nn.Module) -> None:
    """Log the model's structure when ``VLLM_LOG_MODEL_INSPECTION`` is set.

    Does nothing when the flag is falsy. The formatter is imported only on
    the enabled path, so the disabled path never touches
    ``vllm.model_inspection``.
    """
    if envs.VLLM_LOG_MODEL_INSPECTION:
        from vllm.model_inspection import format_model_inspection

        logger.info("vLLM model structure:\n%s", format_model_inspection(model))
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Model inspection utilities for vLLM."""
import torch.nn as nn
def _get_module_info(module: nn.Module) -> str:
    """Build a one-line ``ClassName(details)`` summary of a module.

    The details are, in order:

    * ``quant=<Name>`` when the module has a non-``None`` ``quant_method``
      attribute. ``<Name>`` is the type name of the module's ``scheme``
      attribute when that is set (used for CompressedTensors), otherwise the
      type name of ``quant_method``. Names containing ``"Unquantized"`` are
      omitted.
    * the module's ``extra_repr()`` text, flattened onto one line.

    Empty pieces are dropped, so a quantized module whose ``extra_repr()`` is
    empty renders as ``Cls(quant=X)`` rather than ``Cls(quant=X, )``. A
    multi-line ``extra_repr()`` has its lines joined with ``", "`` instead of
    being fused together without a separator.

    Args:
        module: The module to describe.

    Returns:
        ``ClassName(...)`` for any object exposing ``extra_repr`` or
        contributing a ``quant=`` entry (``ClassName()`` when there is nothing
        to show); otherwise ``str(module)``.
    """
    class_name = type(module).__name__
    parts: list[str] = []

    quant_method = getattr(module, "quant_method", None)
    if quant_method is not None:
        # For CompressedTensors, show the underlying scheme instead.
        scheme = getattr(module, "scheme", None)
        quant_name = type(quant_method if scheme is None else scheme).__name__
        # Unquantized methods carry no useful information; skip them.
        if "Unquantized" not in quant_name:
            parts.append(f"quant={quant_name}")

    has_extra_repr = hasattr(module, "extra_repr")
    if has_extra_repr:
        # Flatten line by line: strip padding and any trailing comma so that
        # both "a=1\nb=2" and "a=1,\n  b=2" come out as "a=1, b=2".
        fields = (
            line.strip().rstrip(",") for line in module.extra_repr().split("\n")
        )
        extra = ", ".join(field for field in fields if field)
        if extra:
            parts.append(extra)

    if parts or has_extra_repr:
        return f"{class_name}({', '.join(parts)})"

    # For unknown objects without extra_repr, use their default string form.
    return str(module)
def _get_child_signature(child: nn.Module) -> str:
    """Return a structural fingerprint of ``child`` for duplicate detection.

    The fingerprint has one ``name:info`` line for every entry yielded by
    ``child.named_modules()``, where ``info`` is the one-line module summary.
    Two children with equal fingerprints are treated as identical layers.
    """
    return "\n".join(
        f"{path}:{_get_module_info(submodule)}"
        for path, submodule in child.named_modules()
    )
def _format_index_ranges(indices: list[int]) -> str:
    """Format indices into range notation.

    Example: ``[0, 1, 2, 4, 5, 6]`` -> ``'0-2, 4-6'``.

    The input does not need to be sorted. Duplicate indices are collapsed so
    they cannot split a run (``[1, 1, 2]`` -> ``'1-2'``), and an empty input
    yields an empty string instead of raising ``IndexError``.

    Args:
        indices: Integer indices in any order.

    Returns:
        Comma-separated runs; a run of one index is written as that index,
        longer runs as ``start-end``.
    """
    ordered = sorted(set(indices))
    if not ordered:
        return ""

    ranges: list[str] = []
    start = end = ordered[0]
    for idx in ordered[1:]:
        if idx == end + 1:
            # Still inside the current consecutive run.
            end = idx
            continue
        ranges.append(str(start) if start == end else f"{start}-{end}")
        start = end = idx
    # Flush the final run.
    ranges.append(str(start) if start == end else f"{start}-{end}")
    return ", ".join(ranges)
def _format_module_tree(
    module: nn.Module,
    name: str = "",
    indent: int = 0,
) -> list[str]:
    """Format a module tree with indentation, grouping identical layers.

    Produces output like:

        (layers): ModuleList(
         (0-27, 29-47): 47 x LlamaDecoderLayer(
          ...
         )
         (28, 48): 2 x DifferentDecoderLayer(
          ...
         )
        )

    Children whose names parse as integers are grouped by structural
    signature, and each group is printed once using its first member as the
    representative. Groups are ordered by their lowest index and are emitted
    before the remaining named children.

    A module with children and nothing else to report has the summary
    ``ClassName()``. Its header is written as ``ClassName(`` so the children
    nest inside a single pair of parentheses, rather than ``ClassName()(``.

    Args:
        module: Root of the subtree to render.
        name: Attribute name of ``module`` in its parent; empty for the root
            and for group representatives, which are labelled by the caller.
        indent: Nesting depth, used to build the line prefix.

    Returns:
        The rendered lines, without trailing newlines.
    """
    prefix = " " * indent
    label = f"({name}): " if name else ""
    info = _get_module_info(module)
    children = list(module.named_children())

    # Leaf node - just output the module info.
    if not children:
        return [f"{prefix}{label}{info}"]

    # Non-leaf node - output an opening line and recurse into children.
    # An empty summary "ClassName()" must not keep its "()" in front of the
    # opening parenthesis; summaries with real details are left untouched.
    class_name = type(module).__name__
    header = class_name if info == f"{class_name}()" else info
    lines = [f"{prefix}{label}{header}("]

    # Separate numbered children (e.g., "0", "1") from named ones (e.g., "norm").
    numbered: list[tuple[int, nn.Module]] = []
    non_numbered: list[tuple[str, nn.Module]] = []
    for child_name, child_module in children:
        try:
            numbered.append((int(child_name), child_module))
        except ValueError:
            non_numbered.append((child_name, child_module))

    # Group numbered children by structure signature to collapse identical
    # layers, e.g. layers 0-27 and 29-47 with the same structure become
    # "(0-27, 29-47): 47 x ...".
    if numbered:
        child_prefix = " " * (indent + 1)
        sig_to_group: dict[str, list[tuple[int, nn.Module]]] = {}
        for idx, child_module in numbered:
            sig = _get_child_signature(child_module)
            sig_to_group.setdefault(sig, []).append((idx, child_module))

        # Output groups sorted by first index.
        for group in sorted(sig_to_group.values(), key=lambda g: g[0][0]):
            indices = [idx for idx, _ in group]
            representative = group[0][1]
            # Render unnamed, then rewrite the first line with the index label.
            child_lines = _format_module_tree(representative, "", indent + 1)
            first_line = child_lines[0].lstrip()
            if len(indices) > 1:
                range_str = _format_index_ranges(indices)
                child_lines[0] = (
                    f"{child_prefix}({range_str}): {len(indices)} x {first_line}"
                )
            else:
                child_lines[0] = f"{child_prefix}({indices[0]}): {first_line}"
            lines.extend(child_lines)

    # Output non-numbered children (e.g., "embed_tokens", "norm").
    for child_name, child_module in non_numbered:
        lines.extend(_format_module_tree(child_module, child_name, indent + 1))

    lines.append(f"{prefix})")
    return lines
def format_model_inspection(model: nn.Module) -> str:
    """Render ``model`` as a transformers-style hierarchical string.

    The module tree is formatted line by line starting from ``model`` as the
    unnamed root, and the lines are joined with newlines.
    """
    tree_lines = _format_module_tree(model)
    return "\n".join(tree_lines)
...@@ -118,6 +118,12 @@ class WorkerBase: ...@@ -118,6 +118,12 @@ class WorkerBase:
"""Apply a function on the model inside this worker.""" """Apply a function on the model inside this worker."""
return fn(self.get_model()) return fn(self.get_model())
def get_model_inspection(self) -> str:
    """Return a transformers-style hierarchical view of this worker's model.

    The formatter is imported inside the method rather than at module level.
    """
    from vllm.model_inspection import format_model_inspection

    model = self.get_model()
    return format_model_inspection(model)
def load_model(self) -> None: def load_model(self) -> None:
"""Load model onto target device.""" """Load model onto target device."""
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment