# Documentation Build Guide
This directory contains the source files for the vLLM-Omni documentation.
## Building Documentation Locally
### Prerequisites
Install documentation dependencies:
```bash
uv pip install -e ".[docs]"
```
### Build and Serve Documentation
From the project root:
```bash
# Serve documentation locally (auto-reload on changes)
# This starts a local web server at http://127.0.0.1:8000
mkdocs serve
# Build static site (generates HTML files in site/ directory)
mkdocs build
```
When using `mkdocs serve`, the documentation will be automatically available at `http://127.0.0.1:8000`. The server will automatically reload when you make changes to the documentation files.
## Auto-generating API Documentation
The documentation automatically extracts docstrings from the code using mkdocstrings. To ensure your code is documented:
1. Add docstrings to all public classes, functions, and methods
2. Use Google or NumPy style docstrings (both are supported)
3. Rebuild the documentation to see changes
Example docstring:
```python
class Omni:
"""Main entry point for vLLM-Omni inference.
This class provides a high-level interface for running multi-modal
inference with non-autoregressive models.
Args:
model: Model name or path
stage_configs: Optional stage configurations
**kwargs: Additional arguments passed to the engine
Example:
>>> llm = Omni(model="Qwen/Qwen2.5-Omni")
>>> outputs = llm.generate(prompts="Hello")
"""
```
## Documentation Structure
```
docs/
├── index.md # Main documentation page
├── getting_started/ # Getting started guides
├── architecture/ # Architecture documentation
├── api/ # API reference (auto-generated from code)
├── examples/ # Code examples
└── stylesheets/ # Custom CSS
```
## Publishing Documentation
### GitHub Pages (Recommended)
The documentation is automatically deployed to GitHub Pages using GitHub Actions.
1. **Enable GitHub Pages**:
- Go to repository `Settings` → `Pages`
- Set `Source` to `GitHub Actions`
- Save settings
2. **Push changes**:
```bash
git push origin main
```
3. **Documentation will be available at**:
- `https://vllm-project.github.io/vllm-omni/`
The GitHub Actions workflow (`.github/workflows/docs.yml`) will automatically:
- Build the documentation when you push to `main` branch
- Deploy it to GitHub Pages
- Update the documentation whenever you make changes
### Read the Docs (Alternative)
You can also use Read the Docs for hosting:
1. Sign up at https://readthedocs.org/
2. Import the `vllm-project/vllm-omni` repository
3. Read the Docs will automatically build using `.readthedocs.yml`
4. Documentation will be available at: `https://vllm-omni.readthedocs.io/`
## Configuration
The documentation configuration is in `mkdocs.yml` at the project root.
## Tips
- **API Documentation**: API docs are automatically generated using `mkdocs-api-autonav` and `mkdocstrings`
- No need to manually create API pages - they're generated automatically
- Use `[module.name.ClassName][]` syntax for cross-references in Summary pages
- **Code Snippets**: Use `--8<-- "path/to/file.py"` for including code snippets
- **Markdown**: Use Markdown for all documentation (no need for RST)
- **Material Theme**: Use Material theme features like:
- Admonitions: `!!! note`, `!!! warning`, etc.
- Code blocks with syntax highlighting
- Tabs for organizing content
- Math formulas using `pymdownx.arithmatex`
## Troubleshooting
### Documentation not updating
- Make sure you've saved all files
- If using `mkdocs serve`, it should auto-reload
- Check for syntax errors in `mkdocs.yml`
### API links not working
- Ensure class names match exactly (case-sensitive)
- Check that the module is imported correctly
- Run `mkdocs build --strict` to check for errors
### Build errors
- Check Python version (requires 3.9+)
- Ensure all dependencies are installed: `pip install -e ".[docs]"`
- Check `mkdocs.yml` syntax with `mkdocs build --strict`
# Contributing to vLLM-Omni
Thank you for your interest in contributing to vLLM-Omni! This document provides guidelines and instructions for contributing.
!!! note
We host weekly developer-facing online meetings to discuss milestones and updates **every Tuesday at 19:30 PDT**. Meeting link as well as the past meeting notes can be found [here](https://tinyurl.com/vllm-omni-meeting).
## Getting Started
vLLM-Omni uses `uv` to create and manage Python environments. Please follow the `uv` documentation to install it. After installing `uv`, you can create a new Python environment using the following commands:
```bash
uv venv --python 3.12 --seed
source .venv/bin/activate
```
### Development Environment for vLLM and vLLM-Omni
vLLM-Omni is evolving quickly; please see the [installation guide](../getting_started/installation/README.md) for details. Building from source is recommended so that you develop against the latest environment.
!!! tip
vLLM-Omni is compatible with Python versions 3.10 to 3.12. However, we recommend developing with Python 3.12 to minimize the chance of your local environment clashing with our CI environment.
### Adding a new model to vLLM-Omni
Please check [model implementation](model/README.md) for how to add diffusion and omni-modality models to vLLM-Omni.
### Linting
vLLM-Omni uses `pre-commit` to lint and format the codebase. See [pre-commit documentation](https://pre-commit.com/#usage) if `pre-commit` is new to you. Setting up `pre-commit` is as easy as:
```bash
uv pip install pre-commit
pre-commit install
```
vLLM-Omni's `pre-commit` hooks will now run automatically every time you commit.
!!! tip
You can manually run the `pre-commit` hooks using:
```bash
pre-commit run # runs on staged files
pre-commit run --show-diff-on-failure --color=always --all-files # runs on all files
```
### Documentation
MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, `mkdocs.yml`.
Get started with:
```bash
uv pip install -e ".[docs]"
```
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. From the root of the repository, run:
```bash
mkdocs serve # with API ref (~10 minutes)
API_AUTONAV_EXCLUDE=vllm_omni mkdocs serve # API ref off (~15 seconds)
```
Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready! Open <http://127.0.0.1:8000/> in your browser to see it.
For additional features and advanced configurations, refer to the:
- [MkDocs documentation](https://www.mkdocs.org/)
- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use)
### Testing
vLLM-Omni uses `pytest` to test the codebase.
```bash
# Run all tests
pytest tests/
# Run tests for a single test file with detailed output
pytest -s -v tests/test_omni_llm.py
```
!!! warning
Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU platform to run unit tests locally, rely on the continuous integration system to run the tests for now.
## Issues
If you encounter a bug or have a feature request, please search existing issues first to see if it has already been reported. If not, please file a new issue, providing as much relevant information as possible.
!!! important
If you discover a security vulnerability, please report it by creating a GitHub issue with the `security` label.
## Pull Requests & Code Reviews
Thank you for your contribution to vLLM-Omni! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM-Omni maintain the code quality and improve the efficiency of the review process.
### DCO and Signed-off-by
When contributing changes to this project, you must agree to the [DCO](https://developercertificate.org/). Commits must include a `Signed-off-by:` header which certifies agreement with the terms of the DCO.
Using `-s` with `git commit` will automatically add this header.
!!! tip
You can enable automatic sign-off via your IDE:
- **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window. It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`.
- **VSCode**: Open the Settings editor and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field.
### PR Title and Classification
Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following prefixes:
- `[Bugfix]` for bug fixes.
- `[CI/Build]` for build or continuous integration improvements.
- `[Doc]` for documentation fixes and improvements.
- `[Model]` for adding a new model or improving an existing model. Model name should appear in the title.
- `[Frontend]` for changes to the vLLM-Omni frontend (e.g., OpenAI API server, `OmniLLM` class, etc.).
- `[Kernel]` for changes affecting CUDA kernels or other compute kernels.
- `[Core]` for changes in the core vLLM-Omni logic (e.g., `OmniProcessor`, `OmniARScheduler`, etc.).
- `[Hardware][Vendor]` for hardware-specific changes. The vendor name should appear in the prefix, such as `[Ascend]` for Ascend NPUs.
- `[Misc]` for PRs that do not fit the above categories. Please use this sparingly.
!!! note
If the PR spans more than one category, please include all relevant prefixes.
### Code Quality
The PR needs to meet the following code quality standards:
- We adhere to Google Python style guide and Google C++ style guide.
- Pass all linter checks.
- The code needs to be well-documented to ensure future contributors can easily understand the code.
- Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
- Please add documentation to `docs/` if the PR modifies the user-facing behaviors of vLLM-Omni. It helps vLLM-Omni users understand and utilize the new features or changes.
### Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with `rfc-required` and may not review the PR.
### What to Expect for the Reviews
The goal of the vLLM-Omni team is to be a _transparent reviewing machine_. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM-Omni team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
- After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
- After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM-Omni team.
- After the review, the reviewer will put an `action-required` label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
- Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
## Additional Resources
- [Design Documents](../design/index.md) - Architecture and design documentation
## Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM-Omni. All of your contributions help make vLLM-Omni a great tool and community for everyone!
# CI Failures
What should I do when a CI job fails on my PR, but I don't think my PR caused
the failure?
# Markers for Tests
By adding markers before test functions, tests can later be executed uniformly by simply declaring the corresponding marker type.
## Current Markers
Defined in `pyproject.toml`:
| Marker | Description |
| ------------------ | ------------------------------------------------------- |
| `core_model` | Core model tests (run in each PR) |
| `diffusion` | Diffusion model tests |
| `omni` | Omni model tests |
| `cache` | Cache backend tests |
| `parallel` | Parallelism/distributed tests |
| `cpu` | Tests that run on CPU |
| `gpu` | Tests that run on GPU (auto-added) |
| `cuda` | Tests that run on CUDA (auto-added) |
| `rocm` | Tests that run on AMD/ROCm (auto-added) |
| `npu` | Tests that run on NPU/Ascend (auto-added) |
| `H100` | Tests that require H100 GPU |
| `L4` | Tests that require L4 GPU |
| `MI325` | Tests that require MI325 GPU (AMD/ROCm) |
| `A2` | Tests that require A2 NPU |
| `A3` | Tests that require A3 NPU |
| `distributed_cuda` | Tests that require multi cards on CUDA platform |
| `distributed_rocm` | Tests that require multi cards on ROCm platform |
| `distributed_npu` | Tests that require multi cards on NPU platform |
| `skipif_cuda` | Skip if the num of CUDA cards is less than the required |
| `skipif_rocm` | Skip if the num of ROCm cards is less than the required |
| `skipif_npu` | Skip if the num of NPU cards is less than the required |
| `slow` | Slow tests (may skip in quick CI) |
| `benchmark` | Benchmark tests |
Markers listed as auto-added are applied by the `@hardware_test` decorator.
### Example usage for markers
```python
import pytest

from tests.utils import hardware_test


@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
    num_cards=2,
)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_video_to_audio():
    ...
```
### Decorator: `@hardware_test`
This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions:
1. **Applies platform and resource markers**
Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `A2`, `A3`).
```
@pytest.mark.cuda
@pytest.mark.L4
```
2. **Handles multi-card (distributed) scenarios**
For tests requiring multiple cards, it automatically adds distributed markers such as `distributed_cuda`, `distributed_rocm`, or `distributed_npu`.
```
@pytest.mark.distributed_cuda(num_cards=num_cards)
```
3. **Supports flexible card requirements**
Accepts `num_cards` as either a single integer for all platforms or as a dictionary with per-platform values. If not specified, defaults to 1 card per platform.
4. **Integrates resource validation**
On CUDA, adds a skip marker (`skipif_cuda`) if the system does not have the required number of devices.
Support for `skipif_rocm` and `skipif_npu` will be implemented later.
5. **Runs each test in a new process**
Automatically wraps the distributed test with a decorator (`@create_new_process_for_each_test`) to ensure isolation and compatibility with multi-process hardware backends.
6. **Works with pytest filtering**
Allows tests to be filtered and selected at runtime using standard pytest marker expressions (e.g., `-m "distributed_cuda and L4"`).
#### Example usage for decorator
- Single call for multiple platforms:
```python
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
    num_cards={"cuda": 2, "rocm": 2, "npu": 2},
)
```
or
```python
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
    num_cards=2,
)
```
- `res` must be a dict; supported resources: CUDA (L4/H100), ROCm (MI325), NPU (A2/A3)
- `num_cards` can be int (all platforms) or dict (per platform); defaults to 1 when missing
- `hardware_test` automatically applies `@create_new_process_for_each_test` for distributed tests.
- Distributed markers (`distributed_cuda`, `distributed_rocm`, `distributed_npu`) are auto-added for multi-card cases
- Filtering examples:
- CUDA only: `pytest -m "distributed_cuda and L4"`
- ROCm only: `pytest -m "distributed_rocm and MI325"`
- NPU only: `pytest -m "distributed_npu"`
## Add Support for a New Platform
If you want to add support for a new platform (e.g., "tpu" for a new accelerator), follow these steps:
1. **Extend the marker list in your pytest config** so that platform/resource markers are defined:
```toml
# In pyproject.toml or pytest.ini
[tool.pytest.ini_options]
markers = [
    # ... existing markers ...
    "tpu: Tests that require TPU device",
    "TPU_V3: Tests that require TPU v3 hardware",
    "distributed_tpu: Tests that require multiple TPU devices",
]
```
2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`:
```python
# In vllm-omni/tests/utils.py
def tpu_marks(*, res: str, num_cards: int):
    test_platform = pytest.mark.tpu
    if res == "TPU_V3":
        test_resource = pytest.mark.TPU_V3
    else:
        raise ValueError(
            f"Invalid TPU resource type: {res}. Supported: TPU_V3")
    if num_cards == 1:
        return [test_platform, test_resource]
    else:
        test_distributed = pytest.mark.distributed_tpu(num_cards=num_cards)
        # Optionally: add skipif_tpu when implemented
        return [test_platform, test_resource, test_distributed]
```
3. **Update `hardware_test` to recognize your new platform**:
In the relevant place (see the `hardware_test` implementation), add:
```python
if platform == "tpu":
    marks = tpu_marks(res=resource, num_cards=cards)
```
4. **(Recommended) Add a test using your new markers**:
```python
@hardware_test(
    res={"tpu": "TPU_V3"},
    num_cards=2,
)
def test_my_tpu_feature():
    ...
```
**Summary**:
- Add pytest markers for your new platform/resources
- Implement a marker function (`xxx_marks`)
- Plug into `hardware_test`
- You're done: tests decorated with `@hardware_test` using your platform now automatically get the correct markers, distribution, and isolation!
See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
# Test File Structure and Style Guide
To ensure project maintainability and sustainable development, we encourage contributors to submit test code (unit tests, system tests, or end-to-end tests) alongside their code changes. This document outlines the guidelines for organizing and naming test files.
## Test Types
### Unit Tests and System Tests
For unit tests and system tests, we strongly recommend placing test files in the same directory structure as the source code being tested, using the naming convention `test_*.py`.
### End-to-End (E2E) Tests for Models
End-to-end tests verify the complete functionality of a system or component. For our project, the E2E tests for different omni models are organized into two subdirectories:
- **`tests/e2e/offline_inference/`**: Tests for offline inference modes (e.g., Qwen3Omni offline inference)
- **`tests/e2e/online_serving/`**: Tests for online serving scenarios (e.g., API server tests)
**Example:** The test file for `vllm_omni/entrypoints/omni_llm.py` should be located at `tests/entrypoints/test_omni_llm.py`.
## Test Directory Structure
The ideal directory structure mirrors the source code organization:
```
vllm_omni/ tests/
├── config/ → ├── config/
│ └── model.py │ └── test_model.py
├── core/ → ├── core/
│ └── sched/ │ └── sched/ # Maps to core/sched/
│ ├── omni_ar_scheduler.py │ ├── test_omni_ar_scheduler.py
│ ├── omni_generation_scheduler.py │ ├── test_omni_generation_scheduler.py
│ └── output.py │ └── test_output.py
├── diffusion/ → ├── diffusion/
│ ├── diffusion_engine.py │ ├── test_diffusion_engine.py
│ ├── omni_diffusion.py │ ├── test_omni_diffusion.py
│ ├── attention/ │ ├── attention/ # Maps to diffusion/attention/
│ │ └── backends/ │ │ └── test_*.py
│ ├── models/ │ ├── models/ # Maps to diffusion/models/
│ │ ├── qwen_image/ │ │ ├── qwen_image/
│ │ │ └── ... │ │ │ └── test_*.py
│ │ └── z_image/ │ │ └── z_image/
│ │ └── ... │ │ └── test_*.py
│ └── worker/ │ └── worker/ # Maps to diffusion/worker/
│ └── ... │ └── test_*.py
├── distributed/ → ├── distributed/
│ └── ... │ └── test_*.py
├── engine/ → ├── engine/
│ ├── processor.py │ ├── test_processor.py
│ └── output_processor.py │ └── test_output_processor.py
├── entrypoints/ → ├── entrypoints/
│ ├── omni_llm.py │ ├── test_omni_llm.py # UT: OmniLLM core logic (mocked)
│ ├── omni_stage.py │ ├── test_omni_stage.py # UT: OmniStage logic
│ ├── omni.py │ ├── test_omni.py # E2E: Omni class (offline inference)
│ ├── async_omni.py │ ├── test_async_omni.py # E2E: AsyncOmni class
│ ├── cli/ │ ├── cli/ # Maps to entrypoints/cli/
│ │ └── ... │ │ └── test_*.py
│ └── openai/ │ └── openai/ # Maps to entrypoints/openai/
│ ├── api_server.py │ ├── test_api_server.py # E2E: API server (online serving)
│ └── serving_chat.py │ └── test_serving_chat.py
├── inputs/ → ├── inputs/
│ ├── data.py │ ├── test_data.py
│ ├── parse.py │ ├── test_parse.py
│ └── preprocess.py │ └── test_preprocess.py
├── model_executor/ → ├── model_executor/
│ ├── layers/ │ ├── layers/
│ │ └── mrope.py │ │ └── test_mrope.py
│ ├── model_loader/ │ ├── model_loader/
│ │ └── weight_utils.py │ │ └── test_weight_utils.py
│ ├── models/ │ ├── models/
│ │ ├── qwen2_5_omni/ │ │ ├── qwen2_5_omni/
│ │ │ ├── qwen2_5_omni_thinker.py │ │ │ ├── test_qwen2_5_omni_thinker.py # UT
│ │ │ ├── qwen2_5_omni_talker.py │ │ │ ├── test_qwen2_5_omni_talker.py # UT
│ │ │ └── qwen2_5_omni_token2wav.py │ │ │ └── test_qwen2_5_omni_token2wav.py # UT
│ │ └── qwen3_omni/ │ │ └── qwen3_omni/
│ │ └── ... │ │ └── test_*.py
│ ├── stage_configs/ │ └── stage_configs/ # Configuration tests (if needed)
│ │ └── ... │ └── test_*.py
│ └── stage_input_processors/ │ └── stage_input_processors/
│ └── ... │ └── test_*.py
├── sample/ → ├── sample/
│ └── ... │ └── test_*.py
├── utils/ → ├── utils/
│ └── platform_utils.py │ └── test_platform_utils.py
├── worker/ → ├── worker/
├── gpu_ar_worker.py │ ├── test_gpu_ar_worker.py
├── gpu_generation_worker.py │ ├── test_gpu_generation_worker.py
├── gpu_model_runner.py │ ├── test_gpu_model_runner.py
└── npu/ │ └── npu/ # Maps to worker/npu/
└── ... │ └── test_*.py
└── e2e/ → ├── e2e/ # End-to-end scenarios (no 1:1 source mirror)
├── online_serving/ # Full-stack online serving flows
│ └── (empty for now)
└── offline_inference/ # Full offline inference flows
├── test_qwen2_5_omni.py # Moved from multi_stages/
├── test_qwen3_omni.py # Moved from multi_stages_h100/
├── test_t2i_model.py # Moved from single_stage/
└── stage_configs/ # Shared stage configs
├── qwen2_5_omni_ci.yaml
└── qwen3_omni_ci.yaml
```
### Naming Conventions
- **Unit Tests**: Use the `test_<module_name>.py` format. Example: `omni_llm.py` → `test_omni_llm.py`
- **E2E Tests**: Place in `tests/e2e/offline_inference/` or `tests/e2e/online_serving/` with descriptive names. Example: `tests/e2e/offline_inference/test_qwen3_omni.py`, `tests/e2e/offline_inference/test_diffusion_model.py`
### Best Practices
1. **Mirror Source Structure**: Test directories should mirror the source code structure
2. **Test Type Indicators**: Use comments to indicate test types (UT for unit tests, E2E for end-to-end tests)
3. **Shared Resources**: Place shared test configurations (e.g., CI configs) in appropriate subdirectories
4. **Consistent Naming**: Follow the `test_*.py` naming convention consistently across all test files
## Test Code Requirements
### Coding style
1. **File header**: Add SPDX license header to all test files
2. **Imports**: Please don't use manual `sys.path` modifications; use standard imports instead.
3. **Test type differentiation**:
- Unit tests: Maintain mock style
- E2E tests for models: Consider using OmniRunner uniformly, avoid decorators
4. **Documentation**: Add docstrings to all test functions
5. **Environment variables**: Set uniformly in `conftest.py` or at the top of files
6. **Type annotations**: Add type annotations to all test function parameters
7. **Pytest Markers**: Add necessary markers like `@pytest.mark.core_model` and use `@hardware_test` to declare hardware requirements (see [Markers for Tests](../ci/tests_markers.md) for details).
### Template
#### E2E - Online serving
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Online E2E smoke test for an omni model (video, text, audio → audio).
"""
import os
from pathlib import Path

import openai
import pytest

from tests.utils import hardware_test

# Optional: set process start method for workers
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

models = ["{your model name}"]  # Edit here to load your model
stage_configs = [str(Path(__file__).parent / "stage_configs" / {your model yaml})]  # Edit here to load your model yaml
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]


# OmniServer: used to start the vllm-omni server
class OmniServer:
    xxx


@pytest.fixture
def omni_server(request):
    model, stage_config_path = request.param
    with OmniServer(model, ["--stage-configs-path", stage_config_path]) as server:
        yield server


# Handle request message
@pytest.fixture(scope="session")
def base64_encoded_video() -> str:
    xxx


@pytest.fixture(scope="session")
def dummy_messages_from_video_data(video_data_url: str, content_text: str) -> str:
    xxx


@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
    num_cards={"cuda": 2, "rocm": 2, "npu": 4},
)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_video_to_audio(
    client: openai.OpenAI,
    omni_server,
    base64_encoded_video: str,
) -> None:
    # Set message
    video_data_url = f"data:video/mp4;base64,{base64_encoded_video}"
    messages = dummy_messages_from_video_data(video_data_url)
    # Send request
    chat_completion = client.chat.completions.create(
        model=omni_server.model,
        messages=messages,
    )
    # Verify text output
    text_choice = chat_completion.choices[0]
    assert text_choice.finish_reason == "length"
    # Verify audio output
    audio_choice = chat_completion.choices[1]
    audio_message = audio_choice.message
    if hasattr(audio_message, "audio") and audio_message.audio:
        assert audio_message.audio.data is not None
        assert len(audio_message.audio.data) > 0
```
#### E2E - Offline inference
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Offline E2E smoke test for an omni model (video → audio).
"""
import os
from pathlib import Path

import pytest
from vllm.assets.video import VideoAsset

from tests.utils import hardware_test

from ..multi_stages.conftest import OmniRunner

# Optional: set process start method for workers
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

models = ["{your model name}"]  # Edit here to load your model
stage_configs = [str(Path(__file__).parent / "stage_configs" / {your model yaml})]  # Edit here to load your model yaml
# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]


# function name: test_{input_modality}_to_{output_modality}
# modality candidate: text, image, audio, video, mixed_modalities
@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
    num_cards=2,
)
@pytest.mark.parametrize("test_config", test_params)
def test_video_to_audio(omni_runner: type[OmniRunner], test_config: tuple[str, str]) -> None:
    """Offline inference: video input, audio output."""
    model, stage_config_path = test_config
    with omni_runner(model, seed=42, stage_configs_path=stage_config_path) as runner:
        # Prepare inputs
        video = VideoAsset(name="sample", num_frames=4).np_ndarrays
        outputs = runner.generate_multimodal(
            prompts="Describe this video briefly.",
            videos=video,
        )
        # Minimal assertions: got outputs and at least one audio result
        assert outputs
        has_audio = any(o.final_output_type == "audio" for o in outputs)
        assert has_audio
```
## Checklist before submitting your test files
1. The file is saved in an appropriate place and the file name is clear.
2. The coding style follows the requirements outlined above.
3. **All test functions have appropriate pytest markers**
4. For tests that need to run in CI, please ensure the test is configured under the `./buildkite/` folder.
# Adding a New Model
This section provides comprehensive guidance on how to add a new model to vLLM-Omni.
## Documentation
- **[Adding an Omni-Modality Model](adding_omni_model.md)**: Complete step-by-step guide using Qwen3-Omni as an example.
- **[Adding a Diffusion Model](adding_diffusion_model.md)**: Complete step-by-step guide using Qwen/Qwen-Image-Edit as an example.
## Quick Start
For a quick reference, see the [Adding a New Multi-Stage Model Guide](adding_omni_model.md) and [Adding a New Diffusion Model Guide](adding_diffusion_model.md).
# Adding a Diffusion Model
This guide walks through the process of adding a new Diffusion model to vLLM-Omni, using Qwen/Qwen-Image-Edit as a comprehensive example.
## Table of Contents
1. [Overview](#overview)
2. [Directory Structure](#directory-structure)
3. [Step-by-Step Implementation](#step-by-step-implementation)
4. [Testing](#testing)
5. [Adding a Model Recipe](#adding-a-model-recipe)
## Overview
When adding a new diffusion model to vLLM-Omni, additional adaptation work is required for the following reasons:
+ The new model must follow the framework's parameter-passing mechanisms and inference flow.
+ The model's default implementations must be replaced with optimized modules to achieve better performance.
The diffusion execution flow is as follows:
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-diffusion-flow.png">
<img alt="Diffusion Flow" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-diffusion-flow.png" width=55%>
</picture>
</p>
## Directory Structure
File Structure for Adding a New Diffusion Model
```
vllm_omni/
├── examples/
│   ├── offline_inference/
│   │   └── example script          # reuse existing if possible (e.g., image_edit.py)
│   └── online_serving/
│       └── example script
└── diffusion/
    ├── registry.py                 # Registry work
    ├── request.py                  # Request Info
    └── models/your_model_name/     # Model directory (e.g., qwen_image)
        └── pipeline_xxx.py         # Model implementation (e.g., pipeline_qwen_image_edit.py)
```
## Step-by-Step Implementation
## Step 1: Model Implementation
The diffusion pipeline’s implementation follows **HuggingFace Diffusers**.
### 1.1 Define the Pipeline Class
Define the pipeline class, e.g., `QwenImageEditPipeline`, and initialize all required submodules, either from HuggingFace `diffusers` or custom implementations. In `QwenImageEditPipeline`, only `QwenImageTransformer2DModel` is re-implemented to support optimizations such as Ulysses-SP. When adding new models in the future, you can either reuse this re-implemented `QwenImageTransformer2DModel` or extend it as needed.
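The skeleton below is only a rough sketch of this structure; the submodule names and the constructor signature are illustrative assumptions, not the actual `QwenImageEditPipeline` interface.
```python
import torch.nn as nn


class MyImageEditPipeline(nn.Module):
    """Sketch of a diffusion pipeline class (names are illustrative)."""

    def __init__(self, od_config):  # od_config: OmniDiffusionConfig (assumed)
        super().__init__()
        # Submodules loaded from HuggingFace `diffusers` checkpoints, or
        # re-implemented in vLLM-Omni when optimizations (e.g., Ulysses-SP) are needed.
        self.text_encoder = ...   # e.g., a frozen text encoder from diffusers
        self.vae = ...            # e.g., an autoencoder from diffusers
        self.transformer = ...    # e.g., a re-implemented Transformer2DModel
        self.scheduler = ...      # e.g., a diffusers noise scheduler
```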
### 1.2 Pre-Processing and Post-Processing Extraction
Extract the pre-processing and post-processing logic from the pipeline class to follow vLLM-Omni’s execution flow. For Qwen-Image-Edit:
```python
def get_qwen_image_edit_pre_process_func(
    od_config: OmniDiffusionConfig,
):
    """
    Define a pre-processing function that resizes input images and
    pre-processes them for subsequent inference.
    """
```
```python
def get_qwen_image_edit_post_process_func(
    od_config: OmniDiffusionConfig,
):
    """
    Define a post-processing function that post-processes the generated images.
    """
```
### 1.3 Define the forward function
The forward function of `QwenImageEditPipeline` follows the HuggingFace `diffusers` design for the most part. The key differences are:
+ As described in the overview, arguments are passed through `OmniDiffusionRequest`, so we need to get user parameters from it accordingly.
```python
prompt = req.prompt
```
+ Pre/post-processing is handled by the framework elsewhere, so skip it in the forward function (see the sketch below).
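A minimal sketch of a forward method following this pattern is shown below; aside from `req.prompt` and the `negative_prompt` field added in Step 2, the request fields and helper calls are illustrative assumptions rather than the exact vLLM-Omni API.
```python
# A sketch of the pipeline's forward method; helper names and scheduler usage
# are illustrative assumptions, not the exact vLLM-Omni implementation.
def forward(self, req):  # req: OmniDiffusionRequest
    # User parameters come from the request object rather than the call signature.
    prompt = req.prompt
    negative_prompt = req.negative_prompt  # field added in Step 2 below

    # Pre-processing (e.g., image resizing) already ran in the framework,
    # so the forward method only performs the denoising loop.
    latents = self.prepare_latents(req)          # illustrative helper
    for t in self.scheduler.timesteps:           # diffusers-style scheduler loop
        noise_pred = self.transformer(latents, t, prompt, negative_prompt)
        latents = self.scheduler.step(noise_pred, t, latents).prev_sample

    # Post-processing (VAE decode, image conversion) is also handled elsewhere.
    return latents
```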
### 1.4 Replace some ops or layers in DiT component
vLLM-Omni provides a set of optimized operators with better performance and built-in support for parallelism, including attention, rotary embeddings (RoPE), and linear layers.
Below is an example showing how to replace standard Transformer attention and FFN layers with vLLM-Omni implementations:
```python
import torch.nn as nn
import torch.nn.functional as F

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.layers.rope import RotaryEmbedding


class MyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.to_qkv = QKVParallelLinear()
        self.to_out = RowParallelLinear()
        self.rope = RotaryEmbedding(is_neox_style=False)

    def forward(self, hidden_states):
        qkv, _ = self.to_qkv(hidden_states)
        q, k, v = qkv.split(...)
        q, k = self.rope(...)
        attn_output = self.attn(q, k, v)
        output = self.to_out(attn_output)
        return output


class MyFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = ColumnParallelLinear()
        self.fc2 = RowParallelLinear()
        self.act = F.gelu

    def forward(self, hidden_states):
        hidden, _ = self.fc1(hidden_states)
        hidden = self.act(hidden)
        output = self.fc2(hidden)
        return output
```
In this example:
+ Attention uses vLLM-Omni’s optimized attention kernel together with parallel QKV projection and RoPE.
+ Linear layers are replaced with column- and row-parallel variants to enable tensor parallelism.
+ The FFN follows a standard two-layer structure and can be further optimized (e.g., using fused or merged projections) if needed.
### 1.5 Provide a `_repeated_blocks` in DiT model
`_repeated_blocks` lists the small, frequently repeated block(s) of a model -- typically a transformer layer.
It is used for torch.compile optimizations.
```python
_repeated_blocks = ["QwenImageTransformerBlock"]
```
### 1.6 (Optional) implement sequence parallelism
vLLM-Omni has a non-intrusive `_sp_plan` mechanism that enables sequence parallelism without modifying the `forward()` logic.
See [How to parallelize a new model](../../user_guide/diffusion/parallelism_acceleration.md) for details.
### 1.7 (Optional) integrate with Cache-Dit
vLLM-Omni supports acceleration via [Cache-Dit](../../user_guide/diffusion/cache_dit_acceleration.md). Most models compatible with Diffusers can use Cache-Dit seamlessly. For new models, you can extend support by modifying `cache_dit_backend.py`.
## Step 2: Extend OmniDiffusionRequest Fields
User-provided inputs are ultimately passed to the model's forward method through `OmniDiffusionRequest`, so add the required fields here to support the new model.
```python
prompt: str | list[str] | None = None
negative_prompt: str | list[str] | None = None
...
```
## Step 3: Registry
+ Register the diffusion model in `registry.py`:
```python
_DIFFUSION_MODELS = {
    # arch: (mod_folder, mod_relname, cls_name)
    ...
    "QwenImageEditPipeline": (
        "qwen_image",
        "pipeline_qwen_image_edit",
        "QwenImageEditPipeline",
    ),
    ...
}
```
+ Register the pre-processing getter function:
```python
_DIFFUSION_PRE_PROCESS_FUNCS = {
    # arch: pre_process_func
    ...
    "QwenImageEditPipeline": "get_qwen_image_edit_pre_process_func",
    ...
}
```
+ Register the post-processing getter function:
```python
_DIFFUSION_POST_PROCESS_FUNCS = {
    # arch: post_process_func
    ...
    "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func",
    ...
}
```
## Step 4: Add an Example Script
For each newly integrated model, provide an example script under `examples/` that demonstrates how to initialize the pipeline with Omni, pass in user inputs, and generate outputs (a minimal sketch follows the list below).
Key points for writing the example:
+ Use the Omni entrypoint to load the model and construct the pipeline.
+ Show how to format user inputs and pass them via `omni.generate(...)`.
+ Demonstrate the common runtime arguments, such as:
    + model path or model name
    + input image(s) or prompt text
    + key diffusion parameters (e.g., inference steps, guidance scale)
    + optional acceleration backends (e.g., Cache-DiT, TeaCache)
+ Save or display the generated results so users can validate the integration.
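For orientation, a minimal offline example might look like the sketch below. The `Omni` entrypoint and `generate()` call mirror the usage shown earlier in this documentation; the import path and the diffusion-specific arguments are assumptions, so check the existing scripts under `examples/offline_inference/` for the authoritative interface.
```python
# A hedged sketch of an offline example script; the import path and the
# diffusion-specific keyword arguments are assumptions, not the exact API.
from vllm_omni import Omni  # assumed import path


def main():
    # Use the Omni entrypoint to load the model and construct the pipeline.
    omni = Omni(model="Qwen/Qwen-Image-Edit")

    # Format user inputs and pass them via omni.generate(...).
    outputs = omni.generate(
        prompts="Make the sky in this photo look like a sunset.",
        # Input images, inference steps, guidance scale, and optional
        # acceleration backends (e.g., Cache-DiT, TeaCache) would be passed here.
    )

    # Save or display the generated results so users can validate the integration.
    for i, output in enumerate(outputs):
        print(f"output {i}: {output}")


if __name__ == "__main__":
    main()
```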
## Step 5: TeaCache Coefficient Estimation (Optional)
If your model supports TeaCache acceleration, you need to estimate the polynomial coefficients for optimal caching performance.
### 5.1 Add Extractor Function
First, implement an extractor function in `vllm_omni/diffusion/cache/teacache/extractors.py`. The extractor extracts the modulated input and defines how to run transformer blocks:
```python
def extract_your_model_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    **kwargs: Any,
) -> CacheContext:
    # 1. Preprocessing
    temb = module.time_embed(timestep)

    # 2. Extract modulated input (for cache decision)
    modulated_input = module.transformer_blocks[0].norm1(hidden_states, temb)

    # 3. Define transformer execution
    def run_transformer_blocks():
        h = hidden_states
        for block in module.transformer_blocks:
            h = block(h, temb=temb)
        return (h,)

    # 4. Define postprocessing
    def postprocess(h):
        return module.proj_out(module.norm_out(h, temb))

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=hidden_states,
        encoder_hidden_states=None,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )
```
Register it in `EXTRACTOR_REGISTRY`:
```python
EXTRACTOR_REGISTRY = {
    ...
    "YourTransformer2DModel": extract_your_model_context,
}
```
### 5.2 Add Adapter for Coefficient Estimation
Add an adapter in `vllm_omni/diffusion/cache/teacache/coefficient_estimator.py`:
```python
class YourModelAdapter:
    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        # Load your pipeline
        ...

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        return pipeline.transformer, "YourTransformer2DModel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)


_MODEL_ADAPTERS["YourModel"] = YourModelAdapter
```
### 5.3 Run Coefficient Estimation
Use the provided script to estimate coefficients:
```python
from vllm_omni.diffusion.cache.teacache.coefficient_estimator import (
    TeaCacheCoefficientEstimator,
)
from datasets import load_dataset
from tqdm import tqdm

# Load model
estimator = TeaCacheCoefficientEstimator(
    model_path="/path/to/model",
    model_type="Bagel",  # Your model type
    device="cuda",
)

# Load prompts (paper suggests ~70 prompts)
dataset = load_dataset("nateraw/parti-prompts", split="train")
prompts = dataset["Prompt"][:70]

# Collect data
for prompt in tqdm(prompts):
    estimator.collect_from_prompt(prompt, num_inference_steps=50)

# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
print(f"Coefficients: {coeffs}")
```
### 5.4 Interpreting Coefficient Estimation Results
The estimator outputs statistics and polynomial coefficients. Here's how to interpret them:
**Example Output:**
```
Data statistics:
Count: 48
Input Diffs (x): min=1.1089e-02, max=5.2555e-02, mean=2.8435e-02
Output Diffs (y): min=2.8242e-02, max=2.9792e-01, mean=7.0312e-02
Coefficients: [1333131.29, -168644.23, 7950.51, -163.75, 1.26]
```
**What to Check:**
- **Count**: Number of timestep pairs analyzed. Should be at least 30-50 for reliable estimation. Low count suggests insufficient prompts or inference steps.
- **Input/Output Ranges**: Verify output differences correlate with input differences. If ranges seem unusual, check your prompt diversity.
- **Coefficient Magnitude**: Extremely large values (>1e8) may indicate numerical instability - try collecting more diverse data.
**Troubleshooting:**
- If results seem unreliable, try:
- Increasing number of prompts (100+ recommended)
- Using more diverse prompts from multiple datasets
- Adjusting `num_inference_steps` (try 20, 50, 100)
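To sanity-check a fitted polynomial, you can evaluate it over the observed range of input differences. The snippet below assumes the coefficients are applied as an ordinary 4th-order polynomial with the highest-order term first (matching the example output above); it is an illustration, not the actual TeaCache rescaling code.
```python
import numpy as np

# Coefficients from the example output above, ordered a4..a0 (highest order first).
coeffs = [1333131.29, -168644.23, 7950.51, -163.75, 1.26]
poly = np.poly1d(coeffs)

# Evaluate across the observed input-diff range (min/max from the statistics above).
for x in np.linspace(1.1089e-02, 5.2555e-02, num=5):
    print(f"input diff {x:.4e} -> rescaled estimate {poly(x):.4e}")
```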
### 5.5 Add Coefficients to Config
Add the estimated coefficients to `vllm_omni/diffusion/cache/teacache/config.py`:
```python
_MODEL_COEFFICIENTS = {
    ...
    "YourTransformer2DModel": [
        1.04730573e+06,   # a4
        -1.34150749e+05,  # a3
        6.51517806e+03,   # a2
        -1.41209108e+02,  # a1
        1.17241808e+00,   # a0
    ],
}
```
## Step 6: Open a Pull Request
When submitting a pull request to add support for a new model, please include the following information in the PR description:
+ Output verification: provide generation outputs to verify correctness and model behavior.
+ Inference speed: provide a comparison with the corresponding implementation in Diffusers.
+ Parallelism support: specify the supported parallel sizes and any relevant limitations.
+ Cache acceleration: state whether the model can be accelerated using Cache-Dit.
Providing these details helps reviewers evaluate correctness, performance improvements, and parallel scalability of the new model integration.
## Testing
For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md).
## Adding a Model Recipe
After implementing and testing your model, please add a model recipe to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository. This helps other users understand how to use your model with vLLM-Omni.
# Adding an Omni-Modality Model
This guide walks through the process of adding a new multi-stage model to vLLM-Omni, using **Qwen3-Omni** as a comprehensive example. Qwen3-Omni is a multi-stage omni-modality model that demonstrates the full capabilities of vLLM-Omni's architecture.
## Table of Contents
1. [Overview](#overview)
2. [Directory Structure](#directory-structure)
3. [Step-by-Step Implementation](#step-by-step-implementation)
4. [Key Components](#key-components)
5. [Model Registration](#model-registration)
6. [Stage Configuration](#stage-configuration)
7. [Stage Input Processors](#stage-input-processors)
8. [Testing](#testing)
9. [Adding a Model Recipe](#adding-a-model-recipe)
10. [Summary](#summary)
## Overview
vLLM-Omni supports multi-stage model architectures where different stages can run on different devices and process different modalities. The Qwen3-Omni model exemplifies this with three stages:
1. **Thinker Stage**: Multimodal understanding (text + audio + video) → text generation
2. **Talker Stage**: Text embeddings → RVQ codec codes
3. **Code2Wav Stage**: RVQ codes → audio waveform
Each stage is implemented as a separate model class that can be configured independently.
## Directory Structure
When adding a new model, you'll need to create the following structure:
```
vllm_omni/model_executor/models/
└── your_model_name/                           # Model directory (e.g., qwen3_omni)
    ├── __init__.py                            # Exports main model class
    ├── your_model.py                          # Main unified model class
    ├── your_model_stage1_implementation.py    # Stage 1 implementation (e.g., thinker)
    ├── your_model_stage2_implementation.py    # Stage 2 implementation (e.g., talker)
    ├── your_model_stage3_implementation.py    # Stage 3 implementation (e.g., code2wav)
    └── ...                                    # Maybe other stage implementations

vllm_omni/model_executor/stage_input_processors/
└── your_model_name.py                         # Stage transition processors

vllm_omni/model_executor/stage_configs/
└── your_model_name.yaml                       # Stage configuration file
```
## Step-by-Step Implementation
### Step 1: Create the Model Directory
Create a new directory under `vllm_omni/model_executor/models/`
### Step 2: Implement Stage Components
For Qwen3-Omni, we have three stage components:
#### 2.1 Thinker Stage (`qwen3_omni_moe_thinker.py`)
The thinker stage handles multimodal understanding. Key features:
- Inherits from base Qwen3 MoE model in vLLM, using vLLM fused ops & page attn to accelerate
- Implements multimodal processing interfaces
- Handles audio, video, and image inputs
- Generates text outputs
```python
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM


class Qwen3OmniMoeThinkerForConditionalGeneration(
    Qwen3MoeForCausalLM,
    SupportsMultiModal,
    SupportsPP,
):
    """Thinker stage: multimodal understanding → text generation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialize base model
        # Set up multimodal processors
        # Configure audio/video/image encoders
        pass
```
#### 2.2 Talker Stage (`qwen3_omni_moe_talker.py`)
The talker stage converts text embeddings to codec codes:
```python
class Qwen3OmniMoeTalkerForConditionalGeneration(
    Qwen3MoeForCausalLM,
    SupportsPP,
):
    """Talker stage: text embeddings → RVQ codec codes."""

    def __init__(self, vllm_config, talker_config, prefix):
        # Initialize base model
        # Replace LM head with codec head
        # Set up text projection from thinker
        pass
```
#### 2.3 Code2Wav Stage (`qwen3_omni_code2wav.py`)
The code2wav stage generates audio waveforms:
```python
class Qwen3OmniMoeCode2Wav(nn.Module):
    """Code2Wav stage: RVQ codes → audio waveform."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialize audio decoder
        # Set up codec processing
        pass
```
### Step 3: Implement the Unified Model Class
The main model class (`qwen3_omni.py`) orchestrates all stages:
```python
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3OmniMoeThinkerMultiModalProcessor,
    info=Qwen3OmniMoeThinkerProcessingInfo,
    dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder,
)
class Qwen3OmniMoeForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, Qwen3OmniMoeConditionalGenerationMixin
):
    """
    Unified Qwen3 Omni MoE model combining thinker, talker, and code2wav.

    Architecture:
        - Thinker: Multimodal understanding (text + audio + video) → text generation
        - Talker: Text embeddings → RVQ codec codes
        - Code2Wav: RVQ codes → audio waveform

    Usage:
        Set `model_stage` in vllm_config to one of: "thinker", "talker", "code2wav"
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.have_multimodal_outputs = True
        config: Qwen3OmniMoeConfig = vllm_config.model_config.hf_config

        # Determine which stage to initialize
        self.model_stage = vllm_config.model_config.model_stage

        if self.model_stage == "thinker":
            # Initialize thinker model
            thinker_vllm_config = vllm_config.with_hf_config(
                config.thinker_config,
                architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"],
            )
            self.thinker = init_vllm_registered_model(
                vllm_config=thinker_vllm_config,
                prefix=maybe_prefix(prefix, "thinker"),
                hf_config=config.thinker_config,
                architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"],
            )
            self.model = self.thinker
        elif self.model_stage == "talker":
            # Initialize talker model
            talker_vllm_config = vllm_config.with_hf_config(
                config.talker_config,
                architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"],
            )
            self.talker = init_vllm_registered_model(
                vllm_config=talker_vllm_config,
                prefix=maybe_prefix(prefix, "talker"),
                hf_config=config.talker_config,
                architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"],
            )
            self.model = self.talker
        elif self.model_stage == "code2wav":
            # Initialize code2wav model
            code2wav_vllm_config = vllm_config.with_hf_config(
                config.code2wav_config,
                architectures=["Qwen3OmniMoeCode2Wav"],
            )
            self.code2wav = init_vllm_registered_model(
                vllm_config=code2wav_vllm_config,
                prefix=maybe_prefix(prefix, "code2wav"),
                hf_config=config.code2wav_config,
                architectures=["Qwen3OmniMoeCode2Wav"],
            )
            self.model = self.code2wav
        else:
            raise ValueError(
                f"Invalid model_stage: {self.model_stage}. "
                f"Must be one of: 'thinker', 'talker', 'code2wav'"
            )
```
#### Key Methods to Implement
1. **`forward()`**: Handles the forward pass for each stage
2. **`embed_input_ids()`**: Embeds input token IDs
3. **`embed_multimodal()`**: Processes multimodal inputs (if applicable)
4. **`compute_logits()`**: Computes logits from hidden states
5. **`load_weights()`**: Loads model weights with proper prefixing of different stages
### Step 4: Create `__init__.py`
Export the main model class:
```python
# vllm_omni/model_executor/models/qwen3_omni/__init__.py
from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration
__all__ = ["Qwen3OmniMoeForConditionalGeneration"]
```
## Key Components
### 1. Model Interfaces
Your model should implement the appropriate interfaces:
- **`SupportsMultiModal`**: For models that process multimodal inputs
- **`SupportsPP`**: For models that support pipeline parallelism
- **`SupportsMRoPE`**: For models using multi-dimensional RoPE (if applicable)
### 2. Multimodal Registration
If your model processes multimodal inputs, register it with the multimodal registry:
```python
@MULTIMODAL_REGISTRY.register_processor(
    YourMultiModalProcessor,
    info=YourProcessingInfo,
    dummy_inputs=YourDummyInputsBuilder,
)
class YourModel(nn.Module, SupportsMultiModal):
    pass
```
### 3. Weight Loading
Implement `load_weights()` to handle weight loading with proper prefixing:
```python
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load weights for all components of the omni model."""
    loaded_weights = set()
    thinker_weights = []
    talker_weights = []
    code2wav_weights = []

    # Separate weights by component
    for k, v in weights:
        if k.startswith("thinker."):
            thinker_weights.append((k, v))
        elif k.startswith("talker."):
            talker_weights.append((k, v))
        elif k.startswith("code2wav."):
            code2wav_weights.append((k, v))

    # Load each component's weights
    if self.thinker and thinker_weights:
        thinker_loaded = self.thinker.load_weights(thinker_weights)
        thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker")
        loaded_weights.update(thinker_loaded)

    # Similar for talker and code2wav...

    return loaded_weights
```
### 4. Output Format
Use `OmniOutput` for stage outputs:
```python
from vllm_omni.model_executor.models.output_templates import OmniOutput
# In forward method
return OmniOutput(
    text_hidden_states=hidden_states,
    multimodal_outputs={"additional_data": data},
    next_token_id=next_token_id,
)
```
## Model Registration
Register your model in `vllm_omni/model_executor/models/registry.py`:
```python
_OMNI_MODELS = {
    # ... existing models ...
    # Your new model
    "YourModelForConditionalGeneration": (
        "your_model_name",  # Module folder name
        "your_model",  # Module file name (without .py)
        "YourModelForConditionalGeneration",  # Class name
    ),
    "YourModelThinkerForConditionalGeneration": (
        "your_model_name",
        "your_model_thinker",
        "YourModelThinkerForConditionalGeneration",
    ),
    # ... other stages ...
}
```
The registry uses lazy loading, so the model class is imported only when needed.
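As a rough illustration of what lazy loading means here, the sketch below resolves a registry entry with `importlib` only when the architecture is requested; the helper name and exact package path are assumptions, not the actual registry implementation.
```python
import importlib


def _resolve_model_cls(arch: str):
    """Illustrative lazy lookup: import the module only when the class is needed."""
    mod_folder, mod_relname, cls_name = _OMNI_MODELS[arch]
    module = importlib.import_module(
        f"vllm_omni.model_executor.models.{mod_folder}.{mod_relname}"  # assumed path
    )
    return getattr(module, cls_name)
```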
## Stage Configuration
Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml).
### Key Configuration Fields
- **`model_stage`**: Which stage to run ("thinker", "talker", "code2wav", etc.)
- **`model_arch`**: The model architecture name (must match registry)
- **`engine_input_source`**: List of stage IDs that provide input to this stage
- **`custom_process_input_func`**: Function to process inputs from previous stages
- **`final_output`**: Whether this stage produces the final output (True/False)
- **`final_output_type`**: Type of final output ("text", "audio", "image", etc.)
## Stage Input Processors
Stage transitions are the mechanism by which outputs from one stage are converted into inputs for the next stage. This section explains where and how stage transitions occur.
### Where Stage Transitions Are Called
Stage transitions happen automatically in the orchestrator (`OmniLLM` class) during the generation loop. Here's the detailed flow:
1. **Location**: `vllm_omni/entrypoints/omni_llm.py` in the `_run_generation()` method
2. **Trigger**: When a stage completes processing and produces outputs
3. **Execution Flow**:
```python
# In omni_llm.py, _run_generation() method (around line 345-460)
# Main orchestrator loop polls each stage for completed requests
for stage_id, stage in enumerate(self.stage_list):
    result = stage.try_collect()  # Get completed request
    if result is None:
        continue

    # Store outputs from this stage
    engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm")
    stage.set_engine_outputs(engine_outputs)

    # Check if there's a next stage to forward to
    next_stage_id = stage_id + 1
    if next_stage_id < len(self.stage_list):
        next_stage: OmniStage = self.stage_list[next_stage_id]

        # THIS IS WHERE STAGE TRANSITION HAPPENS
        next_inputs = next_stage.process_engine_inputs(
            self.stage_list,
            [request_id_to_prompt[req_id]],
        )

        # Submit to next stage
        task = {
            "type": OmniStageTaskType.GENERATE,
            "request_id": req_id,
            "engine_inputs": next_inputs[0],
            "sampling_params": sampling_params_list[next_stage_id],
        }
        next_stage.submit(task)
```
### How Stage Transitions Work
The stage transition process follows these steps:
1. **Stage Completion**: When a stage finishes processing a request, it stores outputs via `stage.set_engine_outputs(engine_outputs)`
2. **Transition Detection**: The orchestrator checks if there's a next stage and calls `process_engine_inputs()` on it
3. **Input Processing**: The `process_engine_inputs()` method in `OmniStage` (`omni_stage.py`) handles the transition:
```python
def process_engine_inputs(
    self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None
) -> list[OmniTokensPrompt | TextPrompt]:
    """Process engine inputs for this stage from upstream stage outputs."""
    if self.custom_process_input_func is None:
        # Default behavior: pass token IDs directly
        # Extract outputs from source stage
        source_stage_id = self.engine_input_source[0]
        source_outputs = stage_list[source_stage_id].engine_outputs
        # ... create OmniTokensPrompt from token_ids ...
    else:
        # Custom transition function (YOUR CODE HERE)
        return self.custom_process_input_func(
            stage_list,
            self.engine_input_source,
            prompt,
            self.requires_multimodal_data,
        )
```
- If `custom_process_input_func` is configured, it calls that function
- Otherwise, it uses default behavior (passing token IDs directly)
4. **Custom Function Execution**: Your custom function receives:
- `stage_list`: List of all stage objects (to access upstream stage outputs)
- `engine_input_source`: List of source stage IDs (e.g., `[0]` for stage 0)
- `prompt`: Original prompt data (for preserving multimodal data)
- `requires_multimodal_data`: Whether multimodal data is required
5. **Output Format**: The function must return a list of `OmniTokensPrompt` objects ready for the next stage
### Data Structures in Stage Transitions
Understanding the data structures is crucial for implementing stage transitions:
**Input to your function:**
- `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects
- Each contains `outputs`: List of `RequestOutput` objects
- Each `RequestOutput` has:
- `token_ids`: Generated token IDs
- `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.
- These are the hidden states or intermediate outputs from the model's forward pass
- `prompt_token_ids`: Original prompt token IDs
**Output from your function:**
- Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains:
- `prompt_token_ids`: List[int] - Token IDs for the next stage
- `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
- `multi_modal_data`: Optional multimodal data if needed
### How Model Outputs Are Stored
The model's `forward()` method returns an `OmniOutput` object that contains:
- `text_hidden_states`: Final hidden states for text generation
- `multimodal_outputs`: Dict containing intermediate outputs
These outputs are captured during the forward pass and stored in `multimodal_output` with specific keys:
```python
# In your model's forward() method (e.g., qwen3_omni.py)
def forward(self, ...):
    # ... processing ...

    # For thinker stage: capture embeddings and hidden states
    multimodal_outputs = {
        "0": captured_embeddings,  # Layer 0 embeddings
        "24": captured_hidden_states,  # Layer 24 hidden states
        "tts_bos_embed": tts_bos_embed,
        "tts_eos_embed": tts_eos_embed,
        # ... other intermediate outputs ...
    }
    return OmniOutput(
        text_hidden_states=hidden_states,
        multimodal_outputs=multimodal_outputs,
    )
```
These keys are then accessible in your stage transition function:
```python
# In stage_input_processors/qwen3_omni.py
thinker_embeddings = output.multimodal_output["0"] # Access by key
thinker_hidden_states = output.multimodal_output["24"]
```
### Key Points
1. **Accessing Upstream Outputs**: Use `stage_list[source_stage_id].engine_outputs` to get outputs from the source stage
2. **Extracting Data**: Access `output.multimodal_output[key]` to get specific hidden states or intermediate results
- Keys are defined by your model's `forward()` method when it creates `multimodal_outputs`
3. **Device Management**: Move tensors to appropriate devices (CPU for serialization, GPU for processing)
4. **Shape Transformations**: Reshape tensors as needed for the next stage (e.g., flattening codec codes)
5. **Batch Handling**: Process each request in the batch separately and return a list
### Complete Flow Diagram
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-dataflow-between-stages.png">
<img alt="Data Flow between stages" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-dataflow-between-stages.png" width=55%>
</picture>
</p>
### Implementation Example
Create stage transition processors in `vllm_omni/model_executor/stage_input_processors/your_model_name.py`:
```python
# qwen3_omni.py
from typing import Any

import torch

# OmniTokensPrompt / TextPrompt imports omitted for brevity.
def thinker2talker(
stage_list: list[Any],
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
) -> list[OmniTokensPrompt]:
"""
Process thinker outputs to create talker inputs.
Args:
stage_list: List of stage objects
engine_input_source: Source stage IDs (typically [0] for thinker)
prompt: Original prompt data
Returns:
List of OmniTokensPrompt for talker stage
"""
source_stage_id = engine_input_source[0]
thinker_outputs = stage_list[source_stage_id].engine_outputs
talker_inputs = []
for thinker_output in thinker_outputs:
output = thinker_output.outputs[0]
# Extract thinker embeddings and hidden states
thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda()
thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda()
info = {
"thinker_embeddings": thinker_embeddings,
"thinker_hidden_states": thinker_hidden_states,
"thinker_sequences": thinker_output.prompt_token_ids + output.token_ids,
"thinker_input_ids": thinker_output.prompt_token_ids,
}
talker_inputs.append(
OmniTokensPrompt(
                prompt_token_ids=[0] * computed_length,  # computed_length: model-specific placeholder length (not shown here)
additional_information=info,
multi_modal_data=None,
)
)
return talker_inputs
def talker2code2wav(
stage_list: list[Any],
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
) -> list[OmniTokensPrompt]:
"""
Process talker outputs to create code2wav inputs.
"""
source_stage_id = engine_input_source[0]
talker_outputs = stage_list[source_stage_id].engine_outputs
code2wav_inputs = []
for talker_output in talker_outputs:
output = talker_output.outputs[0]
# Extract codec codes
        codec_codes = (
            output.multimodal_output["code_predictor_codes"]
            .to(torch.long)
            .transpose(0, 1)
            .cpu()
            .reshape(-1)
            .tolist()
        )
code2wav_inputs.append(
OmniTokensPrompt(
prompt_token_ids=codec_codes,
multi_modal_data=None,
)
)
return code2wav_inputs
```
## Testing
For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md).
## Adding a Model Recipe
After implementing and testing your model, please add a model recipe to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository. This helps other users understand how to use your model with vLLM-Omni.
### What to Include
Your recipe should include:
1. **Model Overview**: Brief description of the model and its capabilities
2. **Installation Instructions**: Step-by-step setup instructions including:
- Installing vllm-omni and dependencies
- Installing any additional required packages (e.g., xformers, diffusers)
- Any version requirements
3. **Usage Examples**: Command-line examples demonstrating how to run the model
4. **Configuration Details**: Important configuration parameters and their meanings
### Example
For reference, see the [LongCat recipe example](https://github.com/vllm-project/recipes/pull/179) which demonstrates the expected format and structure.
### Recipe Location
Create your recipe file in the appropriate directory structure:
- For organization-specific models: `OrganizationName/ModelName.md`
- For general models: `ModelName.md`
The recipe should be a Markdown file that provides clear, reproducible instructions for users to get started with your model.
## Summary
Adding a new model to vLLM-Omni involves:
1. **Create model directory structure** with stage implementations
2. **Implement unified model class** that orchestrates stages
3. **Register model** in `registry.py`
4. **Create stage configuration** YAML file
5. **Implement stage input processors** for stage transitions
6. **Write tests** to verify functionality
7. **Add model recipe** to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository (see [Adding a Model Recipe](#adding-a-model-recipe) section)
### Qwen3-Omni Reference Files
For a complete reference implementation, see:
- **Main model**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py`
- **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py`
- **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py`
- **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py`
- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`
- **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py`
- **Registry**: `vllm_omni/model_executor/models/registry.py`
- **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py`
For more information, see:
- [Architecture Overview](../../design/architecture_overview.md)
- [Supported Models](../../models/supported_models.md)
- [Stage Configuration Guide](../../configuration/stage_configs.md)
# Profiling vLLM-Omni
> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production.
vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**.
### 1. Set the Output Directory
Before running any script, set this environment variable. The system detects this and automatically saves traces here.
```bash
export VLLM_TORCH_PROFILER_DIR=./profiles
```
### 2. Profiling Omni-Modality Models
It is best to limit profiling to one iteration to keep trace files manageable.
```bash
export VLLM_PROFILER_MAX_ITERS=1
```
**Selective Stage Profiling**
By default, the profiler runs across all stages. It is highly recommended to profile only specific stages by passing a `stages` list, which keeps trace files from becoming too large:
```python
# Profile all stages
omni_llm.start_profile()
# Only profile Stage 1
omni_llm.start_profile(stages=[1])
```
```python
# Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni
omni_llm.start_profile(stages=[0, 2])
```
**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`.
```python
import os
import time

# `omni_llm` is an initialized Omni instance, e.g.:
#   from vllm_omni.entrypoints.omni import Omni
#   omni_llm = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct")
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
# 1. Start profiling if enabled
if profiler_enabled:
omni_llm.start_profile(stages=[0])
# Initialize generator
omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator)
total_requests = len(prompts)
processed_count = 0
# Main Processing Loop
for stage_outputs in omni_generator:
# ... [Output processing logic for text/audio would go here] ...
# Update count to track when to stop profiling
processed_count += len(stage_outputs.request_output)
# 2. Check if all requests are done to stop the profiler safely
if profiler_enabled and processed_count >= total_requests:
print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
# Stop the profiler while workers are still active
omni_llm.stop_profile()
# Wait for traces to flush to disk
print("[Info] Waiting 30s for workers to write trace files to disk...")
time.sleep(30)
print("[Info] Trace export wait time finished.")
omni_llm.close()
```
**Examples**:
1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py)
2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py)
### 3. Profiling Diffusion Models
Diffusion profiling is end-to-end, capturing encoding, denoising loops, and decoding.
**CLI Usage:**
```bash
# Minimize spatial dimensions (--height/--width): drastically reduces memory
# usage so the profiler doesn't crash due to overhead, though for accurate
# performance tuning you often want target resolutions.
#
# Minimize the temporal dimension (--num_frames): video models process 3D
# tensors (Time, Height, Width). Reducing frames to the absolute minimum (2)
# keeps tensor sizes small, so the trace file doesn't grow to multiple
# gigabytes.
#
# Minimize the iteration loop (--num_inference_steps): this is the most
# critical setting for profiling. Diffusion models run the same loop X times,
# so profiling 2 steps gives the same performance data as 50 steps, saves
# minutes of runtime, and prevents the trace viewer from freezing.
python image_to_video.py \
    --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
    --image qwen-bear.png \
    --prompt "A cat playing with yarn, smooth motion" \
    --height 48 \
    --width 64 \
    --num_frames 2 \
    --num_inference_steps 2 \
    --guidance_scale 5.0 \
    --guidance_scale_high 6.0 \
    --boundary_ratio 0.875 \
    --flow_shift 12.0 \
    --fps 16 \
    --output i2v_output.mp4
```
**Examples**:
1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py)
2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)
> **Note:**
> As of now, asynchronous (online) profiling is not fully supported in vLLM-Omni. While the `start_profile()` and `stop_profile()` methods exist, they are only reliable in offline inference scripts (e.g., the provided `end2end.py` examples). Do not use them in server-mode or streaming scenarios; traces may be incomplete or fail to flush.
### 4. Analyzing Omni Traces
Output files are saved to your configured `VLLM_TORCH_PROFILER_DIR`.
**Output**
**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI.
**Viewing Tools:**
- [Perfetto](https://ui.perfetto.dev/) (recommended)
- `chrome://tracing` (Chrome only)
**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/)
# Architecture Overview
This document outlines the architectural design for vLLM-Omni.
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/omni-modality-model-architecture.png">
<img alt="Omni-Modality Model Architecture" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/omni-modality-model-architecture.png" width=55%>
</picture>
</p>
# Goals
The primary goal of the vLLM-Omni project is to build the fastest and easiest-to-use open-source Omni-Modality model inference & serving engine. vLLM-Omni extends the original vLLM, which was created to support large language models for text-based autoregressive (AR) generation tasks. vLLM-Omni is designed to support:
* **Non-textual Output:** Enables the integration, efficient processing, and output of various data types, including but not limited to images, audio, and video, alongside text.
* **Non-Autoregressive Structure:** Supports model structures beyond autoregressive generation, especially the Diffusion Transformer (DiT), which is widely used in visual and audio generation.
* **Integration with vLLM Core:** Maintain compatibility and leverage existing vLLM key modules and optimizations where applicable.
* **Extensibility:** Design a modular and flexible architecture that can easily accommodate new modalities, model architectures, and output formats.
# Representative omni-modality models
Based on an analysis of currently popular open-source models, most combine AR and DiT components. They can be further categorized into the three types below:
**DiT as the main structure, with AR as text encoder (e.g., Qwen-Image)**
A powerful image generation foundation model capable of complex text rendering and precise image editing.
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/dit-main-architecture.png">
<img alt="Qwen-Image" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/dit-main-architecture.png" width=30%>
</picture>
</p>
**AR as a main structure, with DiT as multi-modal generator (e.g. BAGEL)**
A unified multimodal comprehension and generation model, with chain-of-thought (CoT) text output and visual generation.
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/ar-main-architecture.png">
<img alt="Bagel" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/ar-main-architecture.png" width=30%>
</picture>
</p>
**AR+DiT (e.g. Qwen-Omni)**
A natively end-to-end omni-modal LLM for multimodal inputs (text/image/audio/video...) and outputs (text/audio...).
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/ar-dit-main-architecture.png">
<img alt="Qwen-Omni" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/ar-dit-main-architecture.png" width=30%>
</picture>
</p>
# vLLM-Omni main architecture
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-main-architecture.png">
<img alt="vLLM-Omni Main Architecture" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-main-architecture.png" width=55%>
</picture>
</p>
## Key Components
| Component | Description |
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| **OmniRouter** | Provides an intelligent router for dispatching omni-modality requests |
| **EntryPoints** | Defines the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provides the OmniStage abstraction for different AR/DiT stages |
| **AR** | Adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management |
| **Diffusion** | Natively implemented and optimized using acceleration components |
| **OmniConnector** | Supports full disaggregation across stages based on E/P/D/G (Encoding/Processing/Decoding/Generation) |
Disaggregated stages are managed through configuration. In the Qwen3-Omni example, stages such as Thinker, Talker, and Code2wav are defined as separate OmniStage instances with their own resources and input/output types.
## Main features
vLLM-Omni aims to be fast, flexible, and easy to use with the following features:
### Performance and Acceleration
The framework achieves high performance through several optimization techniques:
* **Efficient AR Support:** Leverages efficient KV cache management inherited from vLLM.
* **Pipelined Execution:** Uses pipelined stage execution overlapping to ensure high throughput.
* **Full Disaggregation:** Relies on the OmniConnector and dynamic resource allocation across stages.
* **Diffusion Acceleration:** Includes integrated support for diffusion acceleration. This is managed by the acceleration layer, which handles:
    * **Cache:** Includes DBCache, TeaCache, and third-party integrations (e.g., [cache-dit](https://github.com/vipshop/cache-dit)).
* **Parallelism:** Supports TP, CP, USP, and CFG.
* **Attention:** Provides an interface for third-party integration (e.g., FA3, SAGE, MindIE-SD).
* **Quantization:** Supports various quantization implementations including FP8 and AWQ.
* **FusedOps:** Allows for custom and third-party integration.
### Flexibility and Usability
vLLM-Omni is designed to be flexible and straightforward for users:
* **Heterogeneous Pipeline Abstraction:** Manages complex model workflows effectively.
* **Hugging Face Integration:** Offers seamless integration with popular Hugging Face models.
* **Distributed Inference:** Supports tensor, pipeline, data, and expert parallelism.
* **Streaming Outputs:** Supports streaming outputs.
* **Unified API:** Provides a consistent and unified API interface compatible with vLLM.
* **OpenAI-compatible API Server:** Includes a FastAPI-based server for online serving that is compatible with the OpenAI API.
# Interface design
If you use vLLM, then you know how to use vLLM-Omni from Day 0:
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-user-interface.png">
<img alt="vLLM-Omni interface design" src="https://raw.githubusercontent.com/vllm-project/vllm-omni/refs/heads/main/docs/source/architecture/vllm-omni-user-interface.png" width=55%>
</picture>
</p>
Taking **Qwen3-Omni** as an example:
## Offline Inference
The **Omni** class provides a Python interface for offline batched inference. Users initialize the Omni class with a Hugging Face model name and use the generate method, passing inputs that include both text prompts and multi-modal data:
```python
# Create an omni_lm with HF model name.
from vllm_omni.entrypoints.omni import Omni
omni_lm = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct")
# Example prompts.
om_inputs = {"prompt": prompt,
"multi_modal_data": {
"video": video_frames,
"audio": audio_signal,
}}
# Generate texts and audio from the multi-modality inputs.
outputs = omni_lm.generate(om_inputs, sampling_params_list)
```
## Online Serving
Similar to vLLM, vLLM-Omni also provides a FastAPI-based server for online serving. Users can launch the server using the `vllm serve` command with the `--omni` flag:
```bash
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
```
Users can send requests to the server using curl:
```bash
# prepare user content
user_content='[
{
"type": "video_url",
"video_url": {
"url": "'"$SAMPLE_VIDEO_URL"'"
}
},
{
"type": "text",
"text": "Why is this video funny?"
}
]'
sampling_params_list='[
'"$thinker_sampling_params"',
'"$talker_sampling_params"',
'"$code2wav_sampling_params"'
]'
mm_processor_kwargs="{}"
# send the request
curl -sS -X POST http://localhost:8091/v1/chat/completions \
-H "Content-Type: application/json" \
-d @- <<EOF
{
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"sampling_params_list": $sampling_params_list,
"mm_processor_kwargs": $mm_processor_kwargs,
"messages": [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}
]
},
{
"role": "user",
"content": $user_content
}
]
}
EOF
```
For more usages, please refer to [examples](https://github.com/vllm-project/vllm-omni/tree/main/examples).
# Disaggregated Inference for Omni-Modality Models
This guide explains how to configure and use distributed connectors
(`vllm_omni/distributed/omni_connectors`) in vllm-omni for multi-stage pipelines.
Backend-specific setup lives in separate docs:
- [SharedMemoryConnector](omni_connectors/shared_memory_connector.md)
- [MooncakeConnector](omni_connectors/mooncake_connector.md)
- [YuanrongConnector](omni_connectors/yuanrong_connector.md)
## Overview
Connectors enable data transfer between pipeline stages (e.g., Thinker -> Talker).
Current connectors operate in D2H2D (device to host to device) mode.
## Connector Choices
| Use Case | Recommended Connector | Notes |
| :--- | :--- | :--- |
| Single node | SharedMemoryConnector | Auto-configured if no connector is specified. |
| Multi node (Mooncake) | MooncakeConnector | Requires Mooncake Master + metadata server. |
| Multi node (Yuanrong) | YuanrongConnector | Requires Yuanrong Datasystem + etcd. |
## Core API
The connector system is built around `OmniConnectorBase`.
```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class OmniConnectorBase(ABC):
@abstractmethod
def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, Optional[dict]]:
"""
Store data.
Returns: (success, serialized_size, metadata)
"""
pass
@abstractmethod
def get(self, from_stage: str, to_stage: str, get_key: str, metadata: Optional[dict] = None) -> Optional[tuple[Any, int]]:
"""
Retrieve data.
Args: metadata - transport-specific handles returned by put() (e.g., SHM name).
Returns: (object, serialized_size)
"""
pass
```
### Metadata Passing
Some connectors (e.g., SharedMemoryConnector) generate transient resources during `put()`.
This `metadata` must be passed through the control plane so `get()` can locate the data.
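A hedged usage sketch of this handoff (connector construction and the stage/key names below are illustrative, not the actual wiring code):
```python
# Illustrative only: how put()/get() and the returned metadata travel between stages.
connector = SharedMemoryConnector(shm_threshold_bytes=65536)  # construction is an assumption

ok, size, metadata = connector.put(
    from_stage="stage_0", to_stage="stage_1", put_key="req-42", data={"tokens": [1, 2, 3]}
)
assert ok

# ... `metadata` is forwarded over the control plane along with the request ...

result = connector.get(
    from_stage="stage_0", to_stage="stage_1", get_key="req-42", metadata=metadata
)
if result is not None:
    obj, size = result
```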
## Configuration Model
Define connectors in runtime:
```yaml
runtime:
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
```
Wire stages to connectors:
```yaml
stage_args:
- stage_id: 0
output_connectors:
to_stage_1: connector_of_shared_memory
- stage_id: 1
input_connectors:
from_stage_0: connector_of_shared_memory
```
If a pipeline edge has no explicit connector, the system auto-creates a
SharedMemoryConnector for that edge.
## Relationship with vLLM
vLLM provides specialized distributed mechanisms for specific artifacts:
- KV Transfer (`vllm.distributed.kv_transfer`): optimized for KV caches.
- EC Transfer (`vllm.distributed.ec_transfer`): optimized for encoder embeddings.
- Device Communicators (`vllm.distributed.device_communicators`): low-level primitives (NCCL, SHM).
vllm-omni complements this with a generalized connector abstraction:
1. Unifies transport via a single `put`/`get` API for any stage artifact.
2. Enables DAG-style pipelines across processes or nodes with per-edge transports.
3. Can wrap vLLM-specific transfers for KV paths while keeping a consistent interface.
## Operational Notes
- Fail-fast config validation: missing expected edges cause startup failures.
- Missing payloads halt stages: verify connector wiring and metadata propagation.
## Future Roadmap: D2D Transport
Current connectors use D2H2D paths. Future versions will introduce direct
device-to-device connectors (NCCL, UCX, IPC) to reduce latency for large
tensor payloads.
# MooncakeConnector
## When to Use
Best for multi-node distributed inference using Mooncake.
## Installation
```bash
# For CUDA-enabled systems (recommended)
pip install mooncake-transfer-engine
# For non-CUDA systems
pip install mooncake-transfer-engine-non-cuda
```
## Start Mooncake Master
```bash
# If you use Mooncake SSD storage
mkdir -p ./mc_storage
mooncake_master \
--rpc_port=50051 \
--enable_http_metadata_server=true \
--http_metadata_server_host=0.0.0.0 \
--http_metadata_server_port=8080 \
--metrics_port=9003 \
--root_fs_dir=./mc_storage/ \
--cluster_id=mc-local-1 &
```
## Configuration
Define the connector in runtime:
```yaml
runtime:
connectors:
connector_of_mooncake:
name: MooncakeConnector
extra:
host: "127.0.0.1"
metadata_server: "http://<MASTER_IP>:8080/metadata"
master: "<MASTER_IP>:50051"
segment: 512000000
localbuf: 64000000
proto: "tcp"
```
Wire stages to the connector:
```yaml
stage_args:
- stage_id: 0
output_connectors:
to_stage_1: connector_of_mooncake
- stage_id: 1
input_connectors:
from_stage_0: connector_of_mooncake
```
Parameters:
- host: local worker IP registered in the metadata server.
- metadata_server: metadata server URL for discovery and setup.
- master: Mooncake Master address.
- segment: global memory segment size in bytes.
- localbuf: local buffer size in bytes.
- proto: transport protocol ("tcp" or "rdma").
For more details, refer to the
[Mooncake repository](https://github.com/kvcache-ai/Mooncake).
# SharedMemoryConnector
## When to Use
Best for single-node deployments where stages run on the same host. It is
auto-configured when no explicit connector is specified for an edge.
## How It Works
- Small payloads (< threshold): serialized and passed inline in metadata.
- Large payloads (>= threshold): stored in shared memory; the SHM name is
returned in metadata.
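The threshold decision can be pictured with a small conceptual sketch (not the actual connector code; pickle-based serialization is an assumption):
```python
import pickle
from multiprocessing import shared_memory

def put_sketch(data, shm_threshold_bytes=65536):
    """Conceptual sketch of the inline-vs-SHM routing, not the real implementation."""
    blob = pickle.dumps(data)
    if len(blob) < shm_threshold_bytes:
        # Small payload: ship the serialized bytes inline via metadata.
        return True, len(blob), {"inline": blob}
    # Large payload: write into a shared memory block and return its name;
    # the reader opens the block by name and the producer unlinks it later.
    shm = shared_memory.SharedMemory(create=True, size=len(blob))
    shm.buf[: len(blob)] = blob
    return True, len(blob), {"shm_name": shm.name}
```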
## Configuration
```yaml
runtime:
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
```
## Notes
- Auto-mode uses SharedMemoryConnector if no connector is declared for an edge.
# YuanrongConnector
## When to Use
Best for multi-node distributed inference using Yuanrong Datasystem.
## Mechanism
Uses Yuanrong Datasystem's distributed KV store (`datasystem.kv_client`).
- Data Plane: TCP or RDMA for high-bandwidth transfer.
- Control Plane: Yuanrong Datasystem workers and etcd.
- Keying: deterministic keys based on `put_key` (often composed as `request_id:fromStage_toStage`).
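For example, a key following that pattern might be composed as (illustrative):
```python
request_id, from_stage, to_stage = "req-42", "stage0", "stage1"
put_key = f"{request_id}:{from_stage}_{to_stage}"  # -> "req-42:stage0_stage1"
```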
## Installation
```bash
pip install openyuanrong-datasystem
```
## Start etcd
```bash
# Download and install etcd (v3.5.12 or higher)
ETCD_VERSION="v3.5.12"
ETCD_ARCH="linux-arm64"
wget https://github.com/etcd-io/etcd/releases/download/${ETCD_VERSION}/etcd-${ETCD_VERSION}-${ETCD_ARCH}.tar.gz
tar -xvf etcd-${ETCD_VERSION}-${ETCD_ARCH}.tar.gz
cd etcd-${ETCD_VERSION}-${ETCD_ARCH}
sudo cp etcd etcdctl /usr/local/bin/
# Start etcd
etcd \
--name etcd-single \
--data-dir /tmp/etcd-data \
--listen-client-urls http://0.0.0.0:2379 \
--advertise-client-urls http://0.0.0.0:2379 \
--listen-peer-urls http://0.0.0.0:2380 \
--initial-advertise-peer-urls http://0.0.0.0:2380 \
--initial-cluster etcd-single=http://0.0.0.0:2380 &
# Verify etcd is running
etcdctl --endpoints "127.0.0.1:2379" put key "value"
etcdctl --endpoints "127.0.0.1:2379" get key
```
For production environments, refer to the
[official etcd clustering documentation](https://etcd.io/docs/current/op-guide/clustering/).
## Start Datasystem Worker
```bash
# Replace ${ETCD_IP} with etcd node IP, ${WORKER_IP} with local node IP
dscli start -w \
--worker_address "${WORKER_IP}:31501" \
--etcd_address "${ETCD_IP}:2379" \
--shared_memory_size_mb 20480
```
To stop the worker:
```bash
dscli stop --worker_address "${WORKER_IP}:31501"
```
## Configuration
Define the connector in runtime:
```yaml
runtime:
connectors:
connector_of_yuanrong:
name: YuanrongConnector
extra:
host: "127.0.0.1"
port: 31501
get_sub_timeout_ms: 1000
```
Wire stages to the connector:
```yaml
stage_args:
- stage_id: 0
output_connectors:
to_stage_1: connector_of_yuanrong
- stage_id: 1
input_connectors:
from_stage_0: connector_of_yuanrong
```
Parameters:
- host: datasystem worker host.
- port: datasystem worker port.
- get_sub_timeout_ms: get timeout in milliseconds (0 for no timeout).
For more details, refer to the
[Yuanrong Datasystem repository](https://atomgit.com/openeuler/yuanrong-datasystem).
# Distributed utils
This directory (vllm_omni/distributed/ray_utils) contains utilities for distributed execution in vllm-omni, supporting both **Ray** and **Multiprocessing** backends.
## 1. Installation
```bash
pip install "ray[default]"
```
## 2. Ray Utils
The `ray_utils` module provides helper functions for managing Ray clusters and actors, which are essential for:
* **Multi-node deployment**: Running pipeline stages across different physical machines.
* **Resource management**: Efficient GPU/CPU allocation.
### 2.1 Basic Usage
To use the Ray backend, specify `worker_backend="ray"` when initializing the engine.
**Command Line Example:**
```bash
vllm serve Qwen/Qwen2.5-Omni-7B \
--omni \
--port 8091 \
--worker-backend ray \
--ray-address auto
```
### 2.2 Cluster Setup
**Step 1: Start Head Node**
Run this on your primary machine:
```bash
ray start --head --port=6399
```
**Step 2: Connect Worker Nodes**
Run this on each worker machine:
```bash
ray start --address=<HEAD_NODE_IP>:6399
```
> **Tip**: For a complete cluster setup script, refer to the vLLM example:
> [run_cluster.sh](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/run_cluster.sh)
### 2.3 Distributed Connector Support
When running on Ray, the system automatically adapts its communication strategy:
* **Cross-Node**: Recommended to use `MooncakeConnector` (requires separate configuration).
* **Same-Node**: Can still use `SharedMemoryConnector` for efficiency, or Ray's native object store (plasma).
* **SHM threshold default differs**: when `worker_backend="ray"`, the SharedMemoryConnector default threshold is set to `sys.maxsize`, which forces payloads to go inline (no SHM). Override `shm_threshold_bytes` in the connector config if you want SHM for Ray runs.
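For example, to re-enable SHM-backed transfers under Ray, set the threshold explicitly in the connector config (values illustrative):
```yaml
runtime:
  connectors:
    connector_of_shared_memory:
      name: SharedMemoryConnector
      extra:
        shm_threshold_bytes: 65536  # overrides the sys.maxsize default used with worker_backend="ray"
```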
### 2.4 Internal Helpers
* **`initialize_ray_cluster`**: Connects to an existing Ray cluster or starts a local one.
## 3. Troubleshooting
* **Connection Issues**: Ensure the Ray head node is accessible and ports (default 6399 in this example) are open.
* **Version Mismatch**: Ensure all nodes run the same version of Ray and Python.
# Design Documents
This section contains design documents and architecture specifications for vLLM-Omni.
## Architecture Documents
- [Architecture Overview](architecture_overview.md)
## Feature Design Documents
- [Disaggregated Inference](feature/disaggregated_inference.md)
- [Ray-based Execution](feature/ray_based_execution.md)
## Module Design Documents
- [AR Module](module/ar_module.md)
- [DIT Module](module/dit_module.md)
- [Entrypoint Module](module/entrypoint_module.md)
# AutoRegressive (AR) Module
## 1. Overview
The AutoRegressive (AR) module in vLLM-Omni handles autoregressive generation stages, primarily used for text, chain-of-thought (CoT), and audio latent token generation in multi-stage models such as Qwen2.5-Omni, Qwen3-Omni, and BAGEL. Unlike non-autoregressive generation stages (e.g., Diffusion), AR stages generate tokens sequentially, one at a time, following the standard transformer decoder pattern.
The AR module of vLLM-Omni extends vLLM's core components to support:
- **Multimodal inputs/outputs**: Processing images, videos, and audio alongside text
- **Direct embedding transfer**: Passing pre-computed prompt embeddings between pipeline stages via serialized payloads
- **Additional information flow**: Carrying per-request metadata (tensors, lists) through the pipeline
- **Hidden state exposure**: Exposing per-request hidden representations for downstream stages
- **Basic generator support**: Supports basic heterogeneous architectures such as convolution and LSTM
As shown in the [end2end example](../../user_guide/examples/offline_inference/qwen3_omni.md), the AR module can be applied across multiple stages, generating text tokens in the thinker (AR), audio latent tokens in the talker (AR), and audio waveforms in code2wav (convolution).
## 2. Relationship with vLLM
The AR module builds upon the vLLM main framework through inheritance, extending core classes while preserving compatibility with vLLM's scheduling, batching, KV cache management, and execution mechanisms.
### Inheritance Hierarchy
- Scheduler
```mermaid
classDiagram
class VLLMScheduler {
+schedule() SchedulerOutput
+update_from_output() EngineCoreOutputs
}
class OmniARScheduler {
+schedule() SchedulerOutput
}
class OmniGenerationScheduler {
+schedule() SchedulerOutput
+update_from_output() EngineCoreOutputs
}
VLLMScheduler <|-- OmniARScheduler
VLLMScheduler <|-- OmniGenerationScheduler
```
- Worker
```mermaid
classDiagram
class GPUWorker {
+init_device()
+model_runner
}
class GPUARWorker {
+init_device()
}
class GPUGenerationWorker {
+init_device()
}
GPUWorker <|-- GPUARWorker
GPUWorker <|-- GPUGenerationWorker
```
- ModelRunner
```mermaid
classDiagram
class GPUModelRunner {
+execute_model()
+sample_tokens()
}
class OmniGPUModelRunner {
+_update_states()
+_preprocess()
+_model_forward()
}
class GPUARModelRunner {
+execute_model()
+sample_tokens()
}
class GPUGenerationModelRunner {
+execute_model()
}
GPUModelRunner <|-- OmniGPUModelRunner
OmniGPUModelRunner <|-- GPUARModelRunner
OmniGPUModelRunner <|-- GPUGenerationModelRunner
```
- InputProcessor/OutputProcessor
```mermaid
classDiagram
class InputProcessor {
+process_inputs() EngineCoreRequest
}
class OmniInputProcessor {
+process_inputs() OmniEngineCoreRequest
}
class VLLMOutputProcessor {
+process_outputs() OutputProcessorOutput
}
class MultimodalOutputProcessor {
+process_outputs() OutputProcessorOutput
+_route_and_normalize()
}
InputProcessor <|-- OmniInputProcessor
VLLMOutputProcessor <|-- MultimodalOutputProcessor
```
### Key Extensions
- **Scheduler**: `OmniARScheduler` extends `vllm.v1.core.sched.scheduler.Scheduler` to enrich scheduled requests with omni-specific payloads
- **Worker**: `GPUARWorker` extends `vllm.v1.worker.gpu_worker.Worker` to initialize AR-specific model runners
- **ModelRunner**: `GPUARModelRunner` extends `OmniGPUModelRunner` (which in turn extends `vllm.v1.worker.gpu_model_runner.GPUModelRunner`) to expose hidden states and handle multimodal outputs
- **InputProcessor**: `OmniInputProcessor` extends `vllm.v1.engine.input_processor.InputProcessor` to serialize prompt embeddings and additional information
- **OutputProcessor**: `MultimodalOutputProcessor` extends `vllm.v1.engine.output_processor.OutputProcessor` to route and accumulate multimodal outputs
## 3. Scheduler Design
The AR module provides two scheduler implementations: one for standard autoregressive generation and one for basic heterogeneous architectures.
### Request Flow
The following diagram illustrates the request flow through the AR module components:
```mermaid
flowchart TD
A[OmniInputProcessor] -->|OmniEngineCoreRequest| B[OmniARScheduler]
B -->|schedule: OmniNewRequestData| C[GPUARWorker]
C -->|SchedulerOutput| D[GPUARModelRunner]
D -->|execute_model: None| E[Model Forward Pass]
E -->|hidden_states, logits| D
D -->|sample_tokens: OmniModelRunnerOutput| F[OmniARScheduler]
F -->|update_from_output| G[MultimodalOutputProcessor]
G -->|RequestOutput| H[Client/Downstream Stage]
style A fill:#e1f5ff
style B fill:#fff4e1
style C fill:#e8f5e9
style D fill:#f3e5f5
style G fill:#fce4ec
```
The flow follows vLLM's standard pattern: input processing → scheduling → worker execution → output processing, with omni-specific enrichments at each stage.
### OmniARScheduler
`OmniARScheduler` extends the base vLLM scheduler with minimal modifications, focusing on enriching scheduled requests with omni-specific payloads.
#### Modified API: `schedule()`
The scheduler wraps base `NewRequestData` entries with `OmniNewRequestData` to include prompt embeddings and additional information:
```python
def schedule(self) -> SchedulerOutput:
scheduler_output = super().schedule()
# Rewrap base NewRequestData entries with OmniNewRequestData
new_list = []
for nr in scheduler_output.scheduled_new_reqs:
request = self.requests.get(nr.req_id)
omni_nr = OmniNewRequestData(
req_id=nr.req_id,
prompt_token_ids=nr.prompt_token_ids,
# ... other base fields ...
prompt_embeds=getattr(request, "prompt_embeds", None),
additional_information=getattr(request, "additional_information", None),
)
new_list.append(omni_nr)
scheduler_output.scheduled_new_reqs = new_list
return scheduler_output
```
The `update_from_output()` method remains unchanged, inheriting standard request lifecycle management from the base scheduler.
### OmniGenerationScheduler
`OmniGenerationScheduler` implements a fast-path scheduling strategy for basic heterogeneous architectures that process all input tokens in a single step.
#### Modified API: `schedule()`
Allocates all input tokens for a request at once (or 1 placeholder if zero), falling back to default scheduling if budget is insufficient:
```python
def schedule(self) -> SchedulerOutput:
# Fast path: allocate all input tokens at once
while self.waiting and token_budget > 0:
request = self.waiting.peek_request()
required_tokens = max(getattr(request, "num_prompt_tokens", 0), 1)
if required_tokens > token_budget:
break # Fall back to default scheduling
# Allocate and schedule...
```
#### Modified API: `update_from_output()`
Marks requests as finished immediately after one step, since generation models complete in a single forward pass:
```python
def update_from_output(self, ...) -> dict[int, EngineCoreOutputs]:
# ...
# Diffusion request: completes in one step
request.status = RequestStatus.FINISHED_STOPPED
kv_transfer_params = self._free_request(request)
# ...
```
## 4. Worker and ModelRunner Design
### GPUARWorker
`GPUARWorker` initializes the AR-specific model runner while maintaining standard device initialization:
```python
class GPUARWorker(GPUWorker):
def init_device(self):
# ... standard device initialization ...
self.model_runner = GPUARModelRunner(self.vllm_config, self.device)
```
### GPUARModelRunner
`GPUARModelRunner` follows vLLM's two-phase execute/sample flow while exposing hidden states and multimodal outputs.
#### Two-Phase Execution
**Phase 1: `execute_model()`** - Runs forward pass and stores state:
- Computes logits from hidden states
- Stores `ExecuteModelState` with hidden states, logits, and multimodal outputs
- Returns `None` to defer sampling
**Phase 2: `sample_tokens()`** - Samples tokens and builds output:
- Retrieves stored state from `execute_model()`
- Samples tokens using logits
- Extracts per-request hidden states and multimodal outputs
- Builds `OmniModelRunnerOutput` with `pooler_output` containing hidden states
```python
def sample_tokens(self, grammar_output) -> OmniModelRunnerOutput:
# Retrieve stored state
hidden_states, multimodal_outputs = self.execute_model_state
# Sample tokens
sampler_output = self._sample(logits, spec_decode_metadata)
# Extract per-request hidden states
pooler_output = []
for rid in req_ids:
hidden_slice = hidden_states_cpu[start:end]
payload = {"hidden": hidden_slice}
# Add multimodal outputs if present
pooler_output.append(payload)
return OmniModelRunnerOutput(
pooler_output=pooler_output,
# ... other fields ...
)
```
### GPUGenerationModelRunner
`GPUGenerationModelRunner` implements a simplified single-phase execution for basic heterogeneous architectures:
- No logits computation or token sampling
- Direct generation from forward pass in model implementation
- Returns outputs via `pooler_output` immediately after forward pass
### OmniGPUModelRunner
`OmniGPUModelRunner` provides shared functionality for both AR and Generation runners:
#### Prompt Embeddings Overlay
During prefill, overlays custom `prompt_embeds` from request state onto `inputs_embeds`:
```python
def _collect_additional_information_for_prefill(self, num_scheduled_tokens_np):
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
pe_cpu = getattr(req_state, "prompt_embeds_cpu", None)
# Overlay prompt_embeds for prefill portion
if pe_cpu is not None:
src = pe_cpu[num_computed_tokens:num_computed_tokens + overlay_len]
self.inputs_embeds[start_offset:start_offset + overlay_len].copy_(src)
```
#### Additional Information Processing
Decodes and manages `additional_information` payloads:
- Decodes serialized payloads → CPU tensors in request state
- Passes runtime information to model via `runtime_additional_information` kwarg
- Processes model-provided updates via `postprocess()` hook
- Merges updates back into request state
#### M-RoPE Position Initialization
For multimodal models using M-RoPE (e.g., Qwen2-VL), computes position encodings from multimodal feature metadata (image grids, video grids, audio features).
## 5. Input/Output Processing
### Processing Pipeline
The input/output processing pipeline handles serialization, routing, and accumulation of multimodal data:
```mermaid
sequenceDiagram
participant Client
participant OmniInputProcessor
participant Scheduler
participant ModelRunner
participant MultimodalOutputProcessor
Client->>OmniInputProcessor: prompt + prompt_embeds + additional_info
OmniInputProcessor->>OmniInputProcessor: Serialize tensors to payloads
OmniInputProcessor->>Scheduler: OmniEngineCoreRequest (with payloads)
Scheduler->>ModelRunner: OmniNewRequestData (with payloads)
ModelRunner->>ModelRunner: Decode payloads → CPU tensors
ModelRunner->>ModelRunner: Overlay prompt_embeds on inputs_embeds
ModelRunner->>ModelRunner: Forward pass with runtime_additional_information
ModelRunner->>ModelRunner: Extract hidden states + multimodal outputs
ModelRunner->>MultimodalOutputProcessor: OmniModelRunnerOutput (pooler_output)
MultimodalOutputProcessor->>MultimodalOutputProcessor: Route by output_type
MultimodalOutputProcessor->>MultimodalOutputProcessor: Accumulate tensors in OmniRequestState
MultimodalOutputProcessor->>MultimodalOutputProcessor: Consolidate tensor lists
MultimodalOutputProcessor->>Client: RequestOutput (with multimodal_output)
```
### OmniInputProcessor
`OmniInputProcessor` extends the base input processor to serialize prompt embeddings and additional information for inter-stage transfer.
#### Payload Serialization
Converts PyTorch tensors to serialized payloads:
```python
def process_inputs(self, ...) -> OmniEngineCoreRequest:
# Serialize prompt_embeds
if "prompt_embeds" in decoder_inputs:
pe_cpu = decoder_inputs["prompt_embeds"].detach().to("cpu").contiguous()
prompt_embeds_payload = PromptEmbedsPayload(
data=pe_cpu.numpy().tobytes(),
shape=[seq_len, hidden_size],
dtype=dtype_str,
)
# Serialize additional_information
if "additional_information" in decoder_inputs:
entries = {}
for key, value in raw_info.items():
if isinstance(value, torch.Tensor):
entry = AdditionalInformationEntry(
tensor_data=value.numpy().tobytes(),
tensor_shape=list(value.shape),
tensor_dtype=dtype_str,
)
entries[key] = entry
additional_information_payload = AdditionalInformationPayload(entries=entries)
return OmniEngineCoreRequest(
# ... standard fields ...
prompt_embeds=prompt_embeds_payload,
additional_information=additional_information_payload,
)
```
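On the receiving side, such a payload can be turned back into a CPU tensor. A minimal sketch assuming the field names above (the numpy dtype lookup is simplified and illustrative):
```python
import numpy as np
import torch

def decode_prompt_embeds(payload) -> torch.Tensor:
    """Sketch: rebuild a CPU tensor from a serialized PromptEmbedsPayload."""
    array = np.frombuffer(payload.data, dtype=np.dtype(payload.dtype))
    return torch.from_numpy(array.copy()).reshape(payload.shape)
```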
### MultimodalOutputProcessor
`MultimodalOutputProcessor` routes outputs by modality type and accumulates multimodal tensors.
#### Output Routing
Routes `EngineCoreOutput` by `output_type` attribute:
- `"text"`: Standard text generation path
- `"image"`, `"audio"`, `"latents"`: Extract from `pooling_output` or `multimodal_outputs`
- Fallback: Heuristic based on presence of `pooling_output`
#### Tensor Accumulation
`OmniRequestState` accumulates multimodal tensors across multiple steps:
```python
def add_multimodal_tensor(self, payload, mm_type):
# Normalize payload to dict
incoming = {mm_type or "hidden": payload}
# Accumulate: convert tensors to lists for deferred concatenation
if isinstance(v, torch.Tensor) and isinstance(existing, torch.Tensor):
self.mm_accumulated[k] = [existing, v] # List accumulation
```
Before final output, consolidates tensor lists via concatenation:
```python
def _consolidate_multimodal_tensors(self):
for k, v in self.mm_accumulated.items():
if isinstance(v, list) and isinstance(v[0], torch.Tensor):
self.mm_accumulated[k] = torch.cat(v, dim=0) # Concatenate
```
The consolidated tensors are attached to `RequestOutput.multimodal_output` for consumption by downstream stages or clients.
## 6. Summary
The AR module of vLLM-Omni extends vLLM through strategic inheritance and minimal API modifications:
### Key Design Patterns
1. **Inheritance over composition**: Extends vLLM classes to preserve compatibility with existing scheduling, batching, and execution mechanisms
2. **Payload serialization**: Uses `PromptEmbedsPayload` and `AdditionalInformationPayload` for efficient inter-stage data transfer
3. **Two-phase execution**: Maintains vLLM's execute/sample separation for AR models while supporting single-phase execution for generation models
4. **Multimodal routing**: Routes outputs by `output_type` and accumulates tensors incrementally to support streaming
### Differences from vLLM
- **Payload support**: Serialized prompt embeddings and additional information enable direct transfer between pipeline stages
- **Multimodal handling**: Extended input/output processors support images, audio, and other modalities alongside text
- **Hidden state exposure**: AR model runners expose per-request hidden states via `pooler_output` for downstream consumption
- **Generation scheduler**: Fast-path scheduling for basic heterogeneous architectures that complete in one step
The AR module seamlessly integrates with vLLM's existing infrastructure while adding the necessary extensions for multi-stage, multimodal generation pipelines.
---
toc_depth: 4
---
# Diffusion Module Architecture Design
The vLLM-Omni diffusion module (`vllm_omni/diffusion`) is a high-performance inference engine for diffusion models, designed with a modular architecture that separates concerns across multiple components. It provides efficient execution for non-autoregressive generation tasks such as image and video generation.
This document describes the architecture design of the diffusion module, including the diffusion engine, scheduler, worker, diffusion pipeline, and acceleration components.
<p align="center">
<img src="https://github.com/user-attachments/assets/cde1ebea-4006-4d30-8283-15a5721a09d2" alt="vLLM-Omni Diffusion Module Components" width="80%">
</p>
<p align="center">
<em> Main Components of the Diffusion Module </em>
</p>
**Table of Contents:**
- [Architecture Overview](#architecture-overview)
- [Diffusion Engine](#1-diffusion-engine)
- [Scheduler](#2-scheduler)
- [Worker](#3-worker)
- [Diffusion Pipeline](#4-diffusion-pipeline)
- [Acceleration Components](#5-acceleration-components)
- [Attention Backends](#51-attention-backends)
- [Parallel Attention](#52-parallel-attention)
- [Cache Backends](#53-cache-backends)
- [Parallel Strategies](#54-parallel-strategies)
- [Data Flow](#6-data-flow)
---
## Architecture Overview
The diffusion module follows a **multi-process, distributed architecture** with clear separation of concerns:
<p align="center">
<img src="https://github.com/user-attachments/assets/78c4d446-d238-4406-a057-eb17d2ccc8e0" alt="vLLM-Omni Diffusion Module Architecture" width="100%">
</p>
<p align="center">
<em> Diffusion Architecture Overview </em>
</p>
---
## 1. Diffusion Engine
**Location**: `vllm_omni/diffusion/diffusion_engine.py`
### Responsibilities
The `DiffusionEngine` is the **orchestrator** of the diffusion inference system. It manages the lifecycle of worker processes and coordinates the execution flow.
### Key Components
#### 1.1 Initialization
```python
class DiffusionEngine:
def __init__(self, od_config: OmniDiffusionConfig):
self.od_config = od_config
self.post_process_func = get_diffusion_post_process_func(od_config)
self.pre_process_func = get_diffusion_pre_process_func(od_config)
self._processes: list[mp.Process] = []
self._make_client()
```
**Key Features**:
- **Pre/Post Processing**: Registers model-specific pre-processing and post-processing functions via registry pattern
- **Worker Management**: Launches and manages multiple worker processes (one per GPU)
- **Process Isolation**: Uses multiprocessing for true parallelism
#### 1.2 Worker Launch Process
The engine launches workers using a **spawn** method:
```python
def _launch_workers(self, broadcast_handle):
# Creates one process per GPU
for i in range(num_gpus):
process = mp.Process(
target=worker_proc.worker_main,
args=(i, od_config, writer, broadcast_handle),
name=f"DiffusionWorker-{i}",
)
process.start()
```
**Design Decisions**:
- **Spawn Method**: Ensures clean state for each worker (no shared memory issues)
- **Pipe Communication**: Uses `mp.Pipe` for initialization handshake
- **Device Selection**: Each worker is assigned a specific GPU (`cuda:{rank}`)
#### 1.3 Request Processing Flow
```python
def step(self, requests: list[OmniDiffusionRequest]):
# 1. Pre-process requests
requests = self.pre_process_func(requests)
# 2. Send to scheduler and wait for response
output = self.add_req_and_wait_for_response(requests)
# 3. Post-process results
result = self.post_process_func(output.output)
return result
```
**Flow**:
1. **Pre-processing**: Applies model-specific transformations
2. **Scheduling**: Delegates to scheduler for distribution
3. **Post-processing**: Converts raw outputs to final format (e.g., PIL images)
---
## 2. Scheduler
**Location**: `vllm_omni/diffusion/scheduler.py`
### Architecture
The `Scheduler` is implemented as a **singleton** to ensure a single coordination point across the system: only one scheduler instance exists.
### Key Components
#### 2.1 Message Queue System
```python
class Scheduler:
def initialize(self, od_config: OmniDiffusionConfig):
# Broadcast queue: scheduler -> all workers
self.mq = MessageQueue(
n_reader=od_config.num_gpus,
n_local_reader=od_config.num_gpus,
local_reader_ranks=list(range(od_config.num_gpus)),
)
# Result queue: rank 0 worker -> scheduler
self.result_mq = None # Initialized later
```
**Communication Pattern**:
- **Broadcast Queue**: One-to-many communication (scheduler → all workers)
- **Result Queue**: One-to-one communication (rank 0 → scheduler)
- **Shared Memory**: Uses `MessageQueue` (ZMQ-based) for efficient IPC
#### 2.2 Request Distribution
```python
def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
# Broadcast request to all workers
self.mq.enqueue(requests)
# Wait for result from Rank 0
output = self.result_mq.dequeue()
return output
```
**Design Features**:
- **Broadcast Model**: All workers receive the same request (for tensor parallelism)
- **Single Response**: Only rank 0 sends results back (avoids duplicate outputs)
- **Synchronous**: Blocks until result is received (can be made async)
#### 2.3 Singleton Pattern
```python
class Scheduler:
_instance = None
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
# Global singleton instance
scheduler = Scheduler()
```
**Benefits**:
- **Single Point of Control**: Ensures consistent state
- **Easy Access**: Global `scheduler` instance accessible everywhere
- **Resource Management**: Centralized queue management
---
## 3. Worker
**Location**: `vllm_omni/diffusion/worker/gpu_worker.py`
### Architecture
Workers are **independent processes** that execute the actual model inference. Each worker runs on a dedicated GPU and participates in distributed inference.
### Key Components
#### 3.1 Worker Process Structure
```python
class WorkerProc:
def __init__(self, od_config, gpu_id, broadcast_handle):
# Initialize ZMQ context for IPC
self.context = zmq.Context(io_threads=2)
# Connect to broadcast queue (receive requests)
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)
# Create result queue (only rank 0)
if gpu_id == 0:
self.result_mq = MessageQueue(n_reader=1, ...)
# Initialize GPU worker
self.worker = GPUWorker(local_rank=gpu_id, rank=gpu_id, od_config=od_config)
```
**Initialization Steps**:
1. **IPC Setup**: Creates ZMQ context and message queues
2. **Distributed Environment Setup**: Initializes PyTorch distributed communication
- For CUDA GPUs: Uses NCCL (fast GPU communication)
- For NPU: Uses HCCL (Huawei Collective Communications Library)
- For other devices: Uses appropriate backend (GLOO, MCCL, etc.)
3. **Model Loading**: Loads diffusion pipeline on assigned GPU
4. **Cache Setup**: Enables cache backend if configured.
#### 3.2 GPU Worker
```python
class GPUWorker:
def init_device_and_model(self):
# Set distributed environment variables
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Initialize PyTorch distributed
init_distributed_environment(world_size, rank)
parallel_config = self.od_config.parallel_config
initialize_model_parallel(
data_parallel_size=parallel_config.data_parallel_size,
cfg_parallel_size=parallel_config.cfg_parallel_size,
sequence_parallel_size=parallel_config.sequence_parallel_size,
tensor_parallel_size=parallel_config.tensor_parallel_size,
pipeline_parallel_size=parallel_config.pipeline_parallel_size,
)
# Load model
model_loader = DiffusersPipelineLoader(load_config)
self.pipeline = model_loader.load_model(od_config, load_device=f"cuda:{rank}")
# Setup cache backend
from vllm_omni.diffusion.cache.selector import get_cache_backend
self.cache_backend = get_cache_backend(od_config.cache_backend, od_config.cache_config)
if self.cache_backend is not None:
self.cache_backend.enable(self.pipeline)
```
**Key Features**:
- **Tensor Parallelism**: Supports multi-GPU tensor parallelism via PyTorch distributed
- **Model Loading**: Uses `DiffusersPipelineLoader` for efficient weight loading
- **Cache Integration**: Enables cache backends (TeaCache, cache-dit, etc.) transparently
#### 3.3 Worker Busy Loop
```python
def worker_busy_loop(self):
while self._running:
# 1. Receive unified message (generation request, RPC request, or shutdown)
msg = self.recv_message()
# 2. Route message based on type
if isinstance(msg, dict) and msg.get("type") == "rpc":
# Handle RPC request
result, should_reply = self.execute_rpc(msg)
if should_reply:
self.return_result(result)
elif isinstance(msg, dict) and msg.get("type") == "shutdown":
# Handle shutdown message
self._running = False
else:
# Handle generation request (OmniDiffusionRequest list)
output = self.worker.execute_model(msg, self.od_config)
self.return_result(output)
```
**Execution Flow**:
1. **Receive**: Dequeues unified messages from shared memory queue
2. **Route**: Handles different message types (generation, RPC, shutdown)
3. **Execute**: Runs forward pass through pipeline for generation requests
4. **Respond**: Sends results back (rank 0 for generation, specified rank for RPC)
#### 3.4 Model Execution
```python
@torch.inference_mode()
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config):
req = reqs[0] # TODO: support batching
# Refresh cache backend if enabled
if self.cache_backend is not None and self.cache_backend.is_enabled():
self.cache_backend.refresh(self.pipeline, req.num_inference_steps)
# Set forward context for parallelism
with set_forward_context(
vllm_config=self.vllm_config,
omni_diffusion_config=self.od_config
):
output = self.pipeline.forward(req)
return output
```
The model execution leverages multiple parallelism strategies that are transparently applied during the forward pass. The `set_forward_context()` context manager makes parallel group information available throughout the forward pass:
```python
# Inside transformer layers, parallel groups are accessed via:
from vllm_omni.diffusion.distributed.parallel_state import (
get_sp_group, get_dp_group, get_cfg_group, get_pp_group
)
```
**Optimizations**:
- **Cache Refresh**: Clears cache state before each generation for clean state
- **Context Management**: Forward context ensures parallel groups are available during execution
- **Single Request**: Currently processes one request at a time (batching TODO)
---
## 4. Diffusion Pipeline
**Location**: `vllm_omni/diffusion/models/*/pipeline_*.py`
The pipeline is the **model-specific implementation** that orchestrates the diffusion process. Different models (QwenImage, Wan2.2, Z-Image) have their own pipeline implementations.
Most pipeline implementations are adapted from `diffusers`. The multi-step diffusion loop, defined by the `diffuse` function of the pipeline class, is usually the most time-consuming part of the overall inference process. An example is as follows:
```python
def diffuse(self, ...):
for i, t in enumerate(timesteps):
# Forward pass for positive prompt
transformer_kwargs = {
"hidden_states": latents,
"timestep": timestep / 1000,
"encoder_hidden_states": prompt_embeds,
}
noise_pred = self.transformer(**transformer_kwargs)[0]
# Forward pass for negative prompt (CFG)
if do_true_cfg:
neg_transformer_kwargs = {...}
neg_transformer_kwargs["cache_branch"] = "negative"
neg_noise_pred = self.transformer(**neg_transformer_kwargs)[0]
# Combine predictions
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
noise_pred = comb_pred * (cond_norm / noise_norm)
# Scheduler step
latents = self.scheduler.step(noise_pred, t, latents)[0]
return latents
```
**Key Features**:
- **CFG Support**: Handles classifier-free guidance with separate forward passes
- **Cache Branching**: Uses `cache_branch` parameter for cache-aware execution
- **True CFG**: Implements advanced CFG with norm preservation
To learn more about the diffusion pipeline and how to add a new one, please see [Adding Diffusion Model](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/model/adding_diffusion_model).
---
## 5. Acceleration Components
### 5.1 Attention Backends
**Location**: `vllm_omni/diffusion/attention/`
#### Architecture
The attention system uses a **backend selector pattern** that automatically chooses the optimal attention implementation based on hardware and model configuration.
#### Backend Selection
**Location**: `vllm_omni/diffusion/attention/selector.py`
```python
class Attention(nn.Module):
def __init__(self, num_heads, head_size, causal, softmax_scale, ...):
# Auto-select backend
self.attn_backend = get_attn_backend(-1)
self.attn_impl_cls = self.attn_backend.get_impl_cls()
self.attention = self.attn_impl_cls(...)
```
**Available Backends**:
- **FlashAttention**: Optimized CUDA kernel (FA2/FA3) - memory efficient via tiling
- **SDPA**: PyTorch's scaled dot-product attention - default, cross-platform
- **SageAttention**: Sparse attention implementation from SageAttention library
- **AscendAttention**: NPU-optimized attention for Ascend hardware
These backends provide the **kernel implementations** for attention computation. For attention-level sequence parallelism strategies (Ring Attention, Ulysses), see [Parallel Attention](#52-parallel-attention).
#### Backend Selection Mechanism
```python
def get_attn_backend(head_size: int) -> type[AttentionBackend]:
# Check environment variable
backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
if backend_name:
return load_backend(backend_name.upper())
# Default to SDPA
return SDPABackend
```
**Selection Priority**:
1. **Environment Variable**: `DIFFUSION_ATTENTION_BACKEND` for manual override
- Valid values: `FLASH_ATTN`, `TORCH_SDPA`, `SAGE_ATTN`, `ASCEND`
- Example: `export DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN`
2. **Automatic Fallback**: Falls back to SDPA if the selected backend is unavailable
3. **Hardware Detection**: Can select based on device type (NPU, CUDA, etc.)
**Backend Availability** (a minimal probe sketch follows this list):
- **SDPA**: Always available (PyTorch built-in)
- **FlashAttention**: Requires `flash-attn` package installed
- **SageAttention**: Requires `sage-attention` package (from THU-ML GitHub)
- **AscendAttention**: Only available on Ascend NPU hardware
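The availability rules above essentially amount to an import probe. A minimal sketch (not the actual selector code) of falling back to SDPA when the requested backend is not installed:
```python
import importlib.util

# Illustrative probe only; the real selector also performs hardware checks
# (e.g. Ascend NPU detection) before falling back. Package import names are
# assumptions.
_REQUIRED_PACKAGE = {
    "FLASH_ATTN": "flash_attn",
    "SAGE_ATTN": "sageattention",
    "TORCH_SDPA": None,  # built into PyTorch, always available
}

def resolve_backend(name: str) -> str:
    package = _REQUIRED_PACKAGE.get(name)
    if package is not None and importlib.util.find_spec(package) is None:
        return "TORCH_SDPA"  # requested backend not installed -> fall back
    return name
```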
#### Attention Backend Registry
**Location**: `vllm_omni/diffusion/attention/selector.py`
The attention system uses a **registry pattern** to manage and dynamically load attention backends. This allows for easy extension and runtime selection of backends.
**Registry Structure**:
```python
# Registry mapping backend names to their module paths and class names
_BACKEND_CONFIG = {
"FLASH_ATTN": {
"module": "vllm_omni.diffusion.attention.backends.flash_attn",
"class": "FlashAttentionBackend",
},
"TORCH_SDPA": {
"module": "vllm_omni.diffusion.attention.backends.sdpa",
"class": "SDPABackend",
},
"SAGE_ATTN": {
"module": "vllm_omni.diffusion.attention.backends.sage_attn",
"class": "SageAttentionBackend",
},
"ASCEND": {
"module": "vllm_omni.diffusion.attention.backends.ascend_attn",
"class": "AscendAttentionBackend",
},
}
```
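Given this registry, a backend name can be resolved to its class with a dynamic import. A minimal sketch of what `load_backend` is assumed to do:
```python
import importlib

def load_backend(backend_name: str) -> type:
    """Resolve a registered backend name to its class (illustrative sketch)."""
    if backend_name not in _BACKEND_CONFIG:
        raise ValueError(f"Unknown attention backend: {backend_name}")
    entry = _BACKEND_CONFIG[backend_name]
    module = importlib.import_module(entry["module"])
    return getattr(module, entry["class"])
```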
#### Attention Backend Integration
The `Attention` layer integrates backends through a unified interface. Here's how **FlashAttentionBackend** is integrated as an example:
```python
# attention/backends/flash_attn.py
from flash_attn import flash_attn_func  # kernel entry point used below
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 96, 128, 192, 256] # FlashAttention supports these head sizes
class FlashAttentionImpl(AttentionImpl):
def __init__(self, num_heads, head_size, softmax_scale, causal, ...):
self.num_heads = num_heads
self.causal = causal
self.softmax_scale = softmax_scale
def forward(self, query, key, value, attn_metadata=None):
# Call FlashAttention kernel
out = flash_attn_func(
query, key, value,
causal=self.causal,
softmax_scale=self.softmax_scale,
)
return out
```
---
### 5.2 Parallel Attention
**Location**: `vllm_omni/diffusion/attention/parallel/`
#### Architecture
Parallel attention strategies implement **Sequence Parallelism (SP) at the attention layer level**. These strategies distribute attention computation across multiple GPUs by splitting the sequence dimension, using different communication patterns. They work **on top of** AttentionBackend implementations (FlashAttention, SDPA, etc.), handling the parallelization/communication while the backends handle the actual attention computation.
**Key Distinction**: Unlike AttentionBackend (which provides kernel implementations), ParallelAttentionStrategy provides communication patterns for multi-GPU attention parallelism. These strategies implement the `ParallelAttentionStrategy` interface and use AttentionBackend implementations internally.
Both Ring Attention and Ulysses are forms of Sequence Parallelism (SP) that:
- Split the sequence dimension across GPUs
- Contribute to `sequence_parallel_size` (via `ring_degree` and `ulysses_degree`)
- Work at the attention layer level (not model/pipeline level)
#### Ulysses Sequence Parallelism (USP)
**Location**: `vllm_omni/diffusion/attention/parallel/ulysses.py`
USP is a sequence-parallel attention strategy that splits attention computation across multiple GPUs by distributing both the sequence dimension and the attention heads. It uses **all-to-all** collective operations to redistribute Q/K/V tensors before the attention computation and to gather results afterward, which makes attention over very long sequences efficient to parallelize.
Ulysses splits attention computation in two dimensions:
1. **Sequence Dimension**: Splits the sequence length across GPUs
2. **Head Dimension**: Splits attention heads across GPUs
**Configuration**: `ulysses_degree` contributes to `sequence_parallel_size`
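For intuition, the pre-attention all-to-all can be sketched as below. This is illustrative only: the real logic lives in `ulysses.py` and uses the SP process group, whereas this sketch assumes `torch.distributed` is initialized on the default world group and `num_heads` is divisible by the world size.
```python
import torch
import torch.distributed as dist

def ulysses_scatter_heads_gather_seq(x: torch.Tensor) -> torch.Tensor:
    """Turn a [seq/P, num_heads, head_dim] shard into [seq, num_heads/P, head_dim]."""
    world_size = dist.get_world_size()
    seq_shard, num_heads, head_dim = x.shape
    # Group the heads into one chunk per rank: [P, seq/P, heads/P, head_dim]
    x = x.reshape(seq_shard, world_size, num_heads // world_size, head_dim)
    x = x.permute(1, 0, 2, 3).contiguous()
    out = torch.empty_like(x)
    # Head group i is sent to rank i; rank r receives every rank's
    # sequence shard for head group r.
    dist.all_to_all_single(out, x)
    return out.reshape(seq_shard * world_size, num_heads // world_size, head_dim)
```
The inverse all-to-all is applied after the attention kernel to restore the original sequence sharding.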
#### Ring Sequence Parallelism
**Location**: `vllm_omni/diffusion/attention/parallel/ring.py`
Ring Attention is a **parallel attention strategy** that implements sequence parallelism using ring-based point-to-point (P2P) communication. Unlike attention backends that provide the attention kernel implementation, Ring Attention is a **communication pattern** that works on top of attention backends (FlashAttention or SDPA).
Ring Attention splits the sequence dimension across GPUs arranged in a ring topology and is implemented via the `ParallelAttentionStrategy` interface rather than `AttentionBackend`. P2P ring communication circulates Key/Value blocks between neighboring GPUs. Internally, `ring_flash_attn_func` or `ring_pytorch_attn_func` is used depending on the available backend.
**Architecture**:
```python
class RingParallelAttention:
"""Ring sequence-parallel strategy."""
def run_attention(self, query, key, value, attn_metadata, ...):
# Selects underlying attention kernel (FlashAttention or SDPA)
if backend_pref == "sdpa":
return ring_pytorch_attn_func(...) # Uses SDPA kernel
else:
return ring_flash_attn_func(...) # Uses FlashAttention kernel
```
**Integration**:
- Ring Attention is activated when `ring_degree > 1` in parallel config
- It's selected by `build_parallel_attention_strategy()` in the attention layer
- The `Attention` layer routes to `_run_ring_attention()` when Ring is enabled
- Works alongside attention backends: Ring handles communication, backends handle computation
**Configuration**: `ring_degree` contributes to `sequence_parallel_size`
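To make the communication pattern concrete, here is a minimal sketch of one ring step that passes K/V blocks to the next rank. It is illustrative only: it uses the default `torch.distributed` process group rather than the actual SP group, and the real strategy merges the per-step partial outputs with an online softmax inside `ring_flash_attn_func` / `ring_pytorch_attn_func`.
```python
import torch
import torch.distributed as dist

def ring_pass_kv(k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """One ring step: send local K/V to the next rank, receive from the previous."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size
    k_recv, v_recv = torch.empty_like(k), torch.empty_like(v)
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, k, send_to),
        dist.P2POp(dist.irecv, k_recv, recv_from),
        dist.P2POp(dist.isend, v, send_to),
        dist.P2POp(dist.irecv, v_recv, recv_from),
    ])
    for req in reqs:
        req.wait()
    # The caller attends local Q against (k_recv, v_recv) and merges the
    # partial result before the next ring step.
    return k_recv, v_recv
```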
#### Relationship with AttentionBackend
Parallel attention strategies (Ring, Ulysses) work **on top of** AttentionBackend implementations:
- They use AttentionBackend for the actual attention computation (FlashAttention, SDPA, etc.)
- They handle the multi-GPU communication/parallelization layer
- They implement `ParallelAttentionStrategy` interface (not `AttentionBackend`)
For general parallelism strategies (Data Parallelism, Tensor Parallelism, Pipeline Parallelism), see [Parallel Strategies](#54-parallel-strategies).
---
### 5.3 Cache Backends
**Location**: `vllm_omni/diffusion/cache/`
#### Architecture
Cache backends provide a **unified interface** for applying different caching strategies to accelerate diffusion inference. The system supports multiple backends (TeaCache, cache-dit) with a consistent API for enabling and refreshing cache state.
#### Cache Backend Interface
```python
from abc import ABC, abstractmethod
from typing import Any

class CacheBackend(ABC):
def __init__(self, config: DiffusionCacheConfig):
self.config = config
self.enabled = False
@abstractmethod
def enable(self, pipeline: Any) -> None:
"""Enable cache on the pipeline."""
raise NotImplementedError
@abstractmethod
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache state for new generation."""
raise NotImplementedError
def is_enabled(self) -> bool:
"""Check if cache is enabled."""
return self.enabled
```
**Design Pattern**:
- **Abstract Base Class**: Defines contract for all cache backends
- **Pipeline-based**: Works with pipeline instances (not just transformers)
- **State Management**: Provides refresh mechanism for clean state between generations
#### Available Backends
**1. TeaCache Backend**
**Location**: `vllm_omni/diffusion/cache/teacache/backend.py`
```python
class TeaCacheBackend(CacheBackend):
def enable(self, pipeline: Any):
# Extract transformer from pipeline
transformer = pipeline.transformer
transformer_type = transformer.__class__.__name__
# Create TeaCacheConfig from DiffusionCacheConfig
teacache_config = TeaCacheConfig(
transformer_type=transformer_type,
rel_l1_thresh=self.config.rel_l1_thresh,
coefficients=self.config.coefficients,
)
# Apply hooks to transformer
apply_teacache_hook(transformer, teacache_config)
self.enabled = True
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True):
transformer = pipeline.transformer
if hasattr(transformer, "_hook_registry"):
transformer._hook_registry.reset_hook(TeaCacheHook._HOOK_NAME)
```
**TeaCache Features**:
- **Timestep-aware**: Caches based on timestep embedding similarity
- **Adaptive**: Dynamically decides when to reuse cached computations
- **CFG-aware**: Handles positive/negative branches separately
- **Custom Hook System**: Uses a custom forward-interception mechanism (via `HookRegistry`) that wraps the module's `forward` method, allowing transparent integration without modifying model code (a simplified sketch follows)
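The forward-interception idea can be sketched as follows; this is a simplified illustration, not the actual `HookRegistry` implementation (which supports named hooks, reset between generations, and separate CFG branches):
```python
# Simplified sketch: wrap a module's forward with a caching hook.
def apply_cache_hook(module, should_reuse) -> None:
    original_forward = module.forward
    state = {"cached_output": None}

    def wrapped_forward(*args, **kwargs):
        if state["cached_output"] is not None and should_reuse(*args, **kwargs):
            return state["cached_output"]        # skip the expensive forward
        state["cached_output"] = original_forward(*args, **kwargs)
        return state["cached_output"]

    module.forward = wrapped_forward             # transparent to callers
```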
**2. Cache-DiT Backend**
**Location**: `vllm_omni/diffusion/cache/cache_dit_backend.py`
```python
class CacheDiTBackend(CacheBackend):
def enable(self, pipeline: Any):
# Uses cache-dit library for acceleration
# Supports DBCache, SCM (Step Computation Masking), TaylorSeer
# Works with single and dual-transformer architectures
...
self.enabled = True
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True):
# Updates cache context with new num_inference_steps
...
```
**Cache-DiT Features**:
- **DBCache**: Dynamic block caching with configurable compute blocks
- **SCM**: Step Computation Masking for additional speedup
- **TaylorSeer**: Advanced calibration for cache accuracy
- **Dual-transformer Support**: Handles models like Wan2.2 with two transformers
#### Cache Backend Selector
**Location**: `vllm_omni/diffusion/cache/selector.py`
```python
def get_cache_backend(
cache_backend: str | None,
cache_config: dict | DiffusionCacheConfig
) -> CacheBackend | None:
"""Get cache backend instance based on cache_backend string.
Args:
cache_backend: Cache backend name ("cache_dit", "tea_cache", or None)
cache_config: Cache configuration (dict or DiffusionCacheConfig)
Returns:
Cache backend instance or None if cache_backend is "none"
"""
if cache_backend is None or cache_backend == "none":
return None
if isinstance(cache_config, dict):
cache_config = DiffusionCacheConfig.from_dict(cache_config)
if cache_backend == "cache_dit":
return CacheDiTBackend(cache_config)
elif cache_backend == "tea_cache":
return TeaCacheBackend(cache_config)
else:
raise ValueError(f"Unsupported cache backend: {cache_backend}")
```
**Usage Flow** (a minimal wiring sketch follows this list):
1. **Selection**: `get_cache_backend()` returns appropriate backend instance
2. **Enable**: `backend.enable(pipeline)` called during worker initialization
3. **Refresh**: `backend.refresh(pipeline, num_inference_steps)` called before each generation
4. **Check**: `backend.is_enabled()` verifies cache is active
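Put together, the life-cycle looks roughly like this (the config keys and call sites are illustrative; the real wiring happens inside the worker):
```python
# Illustrative wiring of the cache backend life-cycle.
backend = get_cache_backend("tea_cache", {"rel_l1_thresh": 0.2})

# Worker initialization
if backend is not None:
    backend.enable(pipeline)

# Before each generation
if backend is not None and backend.is_enabled():
    backend.refresh(pipeline, num_inference_steps=50)

output = pipeline.forward(req)
```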
### 5.4 Parallel Strategies
**Location**: `vllm_omni/diffusion/distributed/parallel_state.py`
#### Parallelism Types
The system supports multiple orthogonal parallelism strategies:
**Sequence Parallelism (SP)**
- **Purpose**: Split sequence dimension across GPUs
- **Attention-level SP**: Ring Attention and Ulysses (USP) implement SP at the attention layer level
- See [Parallel Attention](#52-parallel-attention) for details
- Configuration: `ulysses_degree` × `ring_degree` = `sequence_parallel_size`
- **Use Case**: Very long sequences (e.g., high-resolution images)
**Data Parallelism (DP)**
- **Purpose**: Replicate model across GPUs, split batch
- **Use Case**: Batch processing, throughput optimization
**Tensor Parallelism (TP)** (Experimental)
- **Purpose**: Split model weights across GPUs
- **Implementation**: Uses vLLM's tensor parallel groups
- **Use Case**: Large models that don't fit on single GPU
**CFG Parallelism** (under development)
- **Purpose**: Parallelize Classifier-Free Guidance (positive/negative prompts)
- **Infrastructure**: CFG parallel groups are initialized and available via `get_cfg_group()`
#### Parallel Group Management
```python
def initialize_model_parallel(
data_parallel_size: int = 1,
cfg_parallel_size: int = 1,
sequence_parallel_size: int | None = None,
ulysses_degree: int = 1,
ring_degree: int = 1,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
vae_parallel_size: int = 0,
):
# Generate orthogonal parallel groups
rank_generator = RankGenerator(
tensor_parallel_size,
sequence_parallel_size,
pipeline_parallel_size,
cfg_parallel_size,
data_parallel_size,
"tp-sp-pp-cfg-dp",
)
# Initialize each parallel group
_DP = init_model_parallel_group(rank_generator.get_ranks("dp"), ...)
_CFG = init_model_parallel_group(rank_generator.get_ranks("cfg"), ...)
_SP = init_model_parallel_group(rank_generator.get_ranks("sp"), ...)
_PP = init_model_parallel_group(rank_generator.get_ranks("pp"), ...)
_TP = init_model_parallel_group(rank_generator.get_ranks("tp"), ...)
```
**Rank Order**: `tp-sp-pp-cfg-dp` (tensor → sequence → pipeline → cfg → data)
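As a worked example (a hypothetical 8-GPU launch): `ulysses_degree=2` and `ring_degree=2` give `sequence_parallel_size = 2 × 2 = 4`, and replicating that group twice with `data_parallel_size=2` requires 4 × 2 = 8 ranks in total:
```python
# Hypothetical 8-GPU configuration: SP = ulysses_degree * ring_degree = 4,
# replicated across 2 data-parallel groups (4 * 2 = 8 ranks).
initialize_model_parallel(
    data_parallel_size=2,
    ulysses_degree=2,
    ring_degree=2,
    sequence_parallel_size=4,
)
```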
**Note**: For attention-level Sequence Parallelism implementations (Ring Attention and Ulysses), see [Parallel Attention](#52-parallel-attention). This section covers higher-level parallelism strategies.
---
## 6. Data Flow
### Complete Request Flow
<p align="center">
<img src="https://github.com/user-attachments/assets/6e093a0d-29c0-4efd-85d5-747002bd2fed" alt="vLLM-Omni Diffusion Module Components" width="100%">
</p>
<p align="center">
<em> End-to-end Data Flow in the vLLM-Omni Diffusion Module </em>
</p>
```
1. User Request
└─> OmniDiffusion.generate(prompt)
└─> Prepare OmniDiffusionRequest
└─> DiffusionEngine.step(requests)
2. Pre-processing
└─> pre_process_func(requests)
└─> Model-specific transformations
3. Scheduling
└─> scheduler.add_req(requests)
└─> Broadcast via MessageQueue to all workers
4. Worker Execution
└─> WorkerProc.worker_busy_loop()
└─> GPUWorker.execute_model(reqs)
└─> Pipeline.forward(req)
├─> encode_prompt()
├─> prepare_latents()
├─> diffuse() [loop]
│ ├─> transformer.forward() [with cache backend hooks]
│ └─> scheduler.step()
└─> vae.decode()
5. Result Collection
└─> Rank 0 sends DiffusionOutput via result queue
└─> Scheduler receives and returns
6. Post-processing
└─> post_process_func(output)
└─> Convert to PIL images / final format
```
---
Architecture design of the entrypoint (to be updated soon).
# Examples
vLLM-Omni's examples are split into two categories:
- If you are using vLLM-Omni from within Python code, see the *Offline Inference* section.
- If you are using vLLM-Omni from an HTTP application or client, see the *Online Serving* section.