diff --git a/.buildkite/bootstrap-amd-omni.sh b/.buildkite/bootstrap-amd-omni.sh new file mode 100644 index 0000000000000000000000000000000000000000..a38b76220110a0b97e951f0e0d18295c5d6e250d --- /dev/null +++ b/.buildkite/bootstrap-amd-omni.sh @@ -0,0 +1,238 @@ +#!/bin/bash +# vllm-omni customized version +# Based on: https://github.com/vllm-project/ci-infra/blob/main/buildkite/bootstrap-amd.sh +# Last synced: 2025-12-15 +# Modifications: Use local template file instead of downloading from ci-infra + +set -euo pipefail + +if [[ -z "${RUN_ALL:-}" ]]; then + RUN_ALL=0 +fi + +if [[ -z "${NIGHTLY:-}" ]]; then + NIGHTLY=0 +fi + +if [[ -z "${VLLM_CI_BRANCH:-}" ]]; then + VLLM_CI_BRANCH="main" +fi + +if [[ -z "${AMD_MIRROR_HW:-}" ]]; then + AMD_MIRROR_HW="amdproduction" +fi + +if [[ -z "${DOCS_ONLY_DISABLE:-}" ]]; then + DOCS_ONLY_DISABLE=0 +fi + +fail_fast() { + DISABLE_LABEL="ci-no-fail-fast" + # If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq + if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then + PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm-omni/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name') + if [[ $PR_LABELS == *"$DISABLE_LABEL"* ]]; then + echo false + else + echo true + fi + else + echo false # not a PR or BUILDKITE_PULL_REQUEST not set + fi +} + +check_run_all_label() { + RUN_ALL_LABEL="ready-run-all-tests" + # If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq + if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then + PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm-omni/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name') + if [[ $PR_LABELS == *"$RUN_ALL_LABEL"* ]]; then + echo true + else + echo false + fi + else + echo false # not a PR or BUILDKITE_PULL_REQUEST not set + fi +} + +if [[ -z "${COV_ENABLED:-}" ]]; then + COV_ENABLED=0 +fi + +upload_pipeline() { + echo "Uploading pipeline..." 
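+  # Copy the local AMD template (test-template-amd-omni.j2) to test-template.j2, render it against
+  # test-amd.yaml with minijinja-cli, and upload the resulting pipeline.yaml to Buildkite.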
+ # Install minijinja + ls .buildkite || buildkite-agent annotate --style error 'Please merge upstream main branch for buildkite CI' + curl -sSfL https://github.com/mitsuhiko/minijinja/releases/download/2.3.1/minijinja-cli-installer.sh | sh + source /var/lib/buildkite-agent/.cargo/env + + if [[ $BUILDKITE_PIPELINE_SLUG == "fastcheck" ]]; then + AMD_MIRROR_HW="amdtentative" + fi + + # Use local template file for vllm-omni + cp .buildkite/test-template-amd-omni.j2 .buildkite/test-template.j2 + + + # (WIP) Use pipeline generator instead of jinja template + if [ -e ".buildkite/pipeline_generator/pipeline_generator.py" ]; then + python -m pip install click pydantic + python .buildkite/pipeline_generator/pipeline_generator.py --run_all=$RUN_ALL --list_file_diff="$LIST_FILE_DIFF" --nightly="$NIGHTLY" --mirror_hw="$AMD_MIRROR_HW" + buildkite-agent pipeline upload .buildkite/pipeline.yaml + exit 0 + fi + echo "List file diff: $LIST_FILE_DIFF" + echo "Run all: $RUN_ALL" + echo "Nightly: $NIGHTLY" + echo "AMD Mirror HW: $AMD_MIRROR_HW" + + FAIL_FAST=$(fail_fast) + + cd .buildkite + ( + set -x + # Output pipeline.yaml with all blank lines removed + minijinja-cli test-template.j2 test-amd.yaml \ + -D branch="$BUILDKITE_BRANCH" \ + -D list_file_diff="$LIST_FILE_DIFF" \ + -D run_all="$RUN_ALL" \ + -D nightly="$NIGHTLY" \ + -D mirror_hw="$AMD_MIRROR_HW" \ + -D fail_fast="$FAIL_FAST" \ + -D vllm_use_precompiled="$VLLM_USE_PRECOMPILED" \ + -D vllm_merge_base_commit="$(git merge-base origin/main HEAD)" \ + -D cov_enabled="$COV_ENABLED" \ + -D vllm_ci_branch="$VLLM_CI_BRANCH" \ + | sed '/^[[:space:]]*$/d' \ + > pipeline.yaml + ) + cat pipeline.yaml + buildkite-agent artifact upload pipeline.yaml + buildkite-agent pipeline upload pipeline.yaml + exit 0 +} + +get_diff() { + $(git add .) + echo $(git diff --name-only --diff-filter=ACMDR $(git merge-base origin/main HEAD)) +} + +get_diff_main() { + $(git add .) + echo $(git diff --name-only --diff-filter=ACMDR HEAD~1) +} + +file_diff=$(get_diff) +if [[ $BUILDKITE_BRANCH == "main" ]]; then + file_diff=$(get_diff_main) +fi + +# ---------------------------------------------------------------------- +# Early exit start: skip pipeline if conditions are met +# ---------------------------------------------------------------------- + +# skip pipeline if all changed files are under docs/ +if [[ "${DOCS_ONLY_DISABLE}" != "1" ]]; then + if [[ -n "${file_diff:-}" ]]; then + docs_only=1 + # Robust iteration over newline-separated file_diff + while IFS= read -r f; do + [[ -z "$f" ]] && continue + # **Policy:** only skip if *every* path starts with docs/ + if [[ "$f" != docs/* ]]; then + docs_only=0 + break + fi + done < <(printf '%s\n' "$file_diff" | tr ' ' '\n' | tr -d '\r') + + if [[ "$docs_only" -eq 1 ]]; then + buildkite-agent annotate ":memo: CI skipped — docs/** only changes detected + +\`\`\` +${file_diff} +\`\`\`" --style "info" || true + echo "[docs-only] All changes are under docs/. Exiting before pipeline upload." + exit 0 + fi + fi +fi + +# ---------------------------------------------------------------------- +# Early exit end +# ---------------------------------------------------------------------- + +patterns=( + "docker/Dockerfile" + "CMakeLists.txt" + "requirements/common.txt" + "requirements/cuda.txt" + "requirements/build.txt" + "requirements/test.txt" + "setup.py" + "csrc/" + "cmake/" +) + +ignore_patterns=( + "docker/Dockerfile." 
+ "csrc/cpu" + "csrc/rocm" + "cmake/hipify.py" + "cmake/cpu_extension.cmake" +) + +for file in $file_diff; do + # First check if file matches any pattern + matches_pattern=0 + for pattern in "${patterns[@]}"; do + if [[ $file == $pattern* ]] || [[ $file == $pattern ]]; then + matches_pattern=1 + break + fi + done + + # If file matches pattern, check it's not in ignore patterns + if [[ $matches_pattern -eq 1 ]]; then + matches_ignore=0 + for ignore in "${ignore_patterns[@]}"; do + if [[ $file == $ignore* ]] || [[ $file == $ignore ]]; then + matches_ignore=1 + break + fi + done + + if [[ $matches_ignore -eq 0 ]]; then + RUN_ALL=1 + echo "Found changes: $file. Run all tests" + break + fi + fi +done + +# Check for ready-run-all-tests label +LABEL_RUN_ALL=$(check_run_all_label) +if [[ $LABEL_RUN_ALL == true ]]; then + RUN_ALL=1 + NIGHTLY=1 + echo "Found 'ready-run-all-tests' label. Running all tests including optional tests." +fi + +# Decide whether to use precompiled wheels +# Relies on existing patterns array as a basis. +if [[ -n "${VLLM_USE_PRECOMPILED:-}" ]]; then + echo "VLLM_USE_PRECOMPILED is already set to: $VLLM_USE_PRECOMPILED" +elif [[ $RUN_ALL -eq 1 ]]; then + export VLLM_USE_PRECOMPILED=0 + echo "Detected critical changes, building wheels from source" +else + export VLLM_USE_PRECOMPILED=1 + echo "No critical changes, using precompiled wheels" +fi + + +LIST_FILE_DIFF=$(get_diff | tr ' ' '|') +if [[ $BUILDKITE_BRANCH == "main" ]]; then + LIST_FILE_DIFF=$(get_diff_main | tr ' ' '|') +fi +upload_pipeline diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 0000000000000000000000000000000000000000..402e625182f39d83b01104cccf23c6313c9af9e4 --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,342 @@ +steps: + - label: ":docker: Build image" + key: image-build + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker build --file docker/Dockerfile.ci -t vllm-omni-ci ." 
+ - "docker tag vllm-omni-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" + - "docker push public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" + agents: + queue: "cpu_queue_premerge" + + # - label: "Test on NPU" + # depends_on: ~ + # key: npu-test + # commands: + # - ".buildkite/scripts/hardware_ci/run_npu_test.sh" + # agents: + # queue: "ascend" + + - label: "Simple Unit Test" + depends_on: image-build + commands: + - pytest -v -s tests/entrypoints/ + - pytest -v -s tests/diffusion/cache/ + - pytest -v -s tests/diffusion/lora/ + - pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py + - pytest -v -s tests/worker/ + - pytest -v -s tests/distributed/omni_connectors/test_kv_flow.py + agents: + queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Model Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Images API LoRA E2E" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/online_serving/test_images_generations_lora.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Model CPU offloading Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py + - pytest -s -v tests/e2e/offline_inference/test_diffusion_layerwise_offload.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Audio Generation Model Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Cache Backend Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_cache_dit.py tests/e2e/offline_inference/test_teacache.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - 
"/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Sequence Parallelism Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion Tensor Parallelism Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "Diffusion GPU Worker Test" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/diffusion/test_diffusion_worker.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + + - label: "Benchmark Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/benchmarks/test_serve_cli.py + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: "Omni Model Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py + agents: + queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + # - label: "Omni Model Test with H100" + # timeout_in_minutes: 30 + # depends_on: image-build + # commands: + # - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # - export VLLM_TEST_CLEAN_GPU_MEMORY="1" + # - pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py + # - pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py + # - pytest -s -v tests/e2e/online_serving/test_async_omni.py + # agents: + # queue: "mithril-h100-pool" + # plugins: + # - kubernetes: + # podSpec: + # containers: + # - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # resources: + # limits: + # nvidia.com/gpu: 2 + # volumeMounts: + # - name: devshm + # 
mountPath: /dev/shm + # - name: hf-cache + # mountPath: /root/.cache/huggingface + # env: + # - name: HF_HOME + # value: /root/.cache/huggingface + # nodeSelector: + # node.kubernetes.io/instance-type: gpu-h100-sxm + # volumes: + # - name: devshm + # emptyDir: + # medium: Memory + # - name: hf-cache + # hostPath: + # path: /mnt/hf-cache + # type: DirectoryOrCreate + + - label: "Diffusion Image Edit Test with H100 (1 GPU)" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + # - label: "Bagel Text2Img Model Test with H100" + # timeout_in_minutes: 30 + # depends_on: image-build + # commands: + # - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # - pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py + # agents: + # queue: "mithril-h100-pool" + # plugins: + # - kubernetes: + # podSpec: + # containers: + # - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # resources: + # limits: + # nvidia.com/gpu: 1 + # volumeMounts: + # - name: devshm + # mountPath: /dev/shm + # - name: hf-cache + # mountPath: /root/.cache/huggingface + # env: + # - name: HF_HOME + # value: /root/.cache/huggingface + # nodeSelector: + # node.kubernetes.io/instance-type: gpu-h100-sxm + # volumes: + # - name: devshm + # emptyDir: + # medium: Memory + # - name: hf-cache + # hostPath: + # path: /mnt/hf-cache + # type: DirectoryOrCreate diff --git a/.buildkite/scripts/docker_login_ecr_public.sh b/.buildkite/scripts/docker_login_ecr_public.sh new file mode 100644 index 0000000000000000000000000000000000000000..51c5e1a5d5d06ad8000883c21cf4ad26fa36ae95 --- /dev/null +++ b/.buildkite/scripts/docker_login_ecr_public.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Helper function to safely login to ECR Public with per-job config isolation +# Uses DOCKER_CONFIG environment variable to prevent race conditions +# +# This script prevents the "device or resource busy" error by giving each +# Buildkite job its own isolated Docker config directory. 
+# +# Usage: +# source docker_login_ecr_public.sh && safe_docker_login_ecr_public + +set -euo pipefail + +# Configuration +ECR_REGISTRY="public.ecr.aws" + +setup_isolated_docker_config() { + # Use BUILDKITE_JOB_ID for job-specific isolation + # Fallback to PID if running outside Buildkite + local job_id="${BUILDKITE_JOB_ID:-$$}" + + # Set Docker config to job-specific directory + export DOCKER_CONFIG="/tmp/docker-config-${job_id}" + + # Create directory if it doesn't exist + mkdir -p "$DOCKER_CONFIG" + + echo "[docker-config] Using isolated Docker config: $DOCKER_CONFIG" +} + +check_docker_auth() { + # Check if already authenticated to the given registry + # Returns 0 if authenticated, 1 if not + local registry="$1" + + # Check if credentials exist in the isolated config + if [[ -f "$DOCKER_CONFIG/config.json" ]]; then + # Check if registry is present in config + if grep -q "$registry" "$DOCKER_CONFIG/config.json" 2>/dev/null; then + return 0 + fi + fi + + return 1 +} + +safe_docker_login_ecr_public() { + # Setup isolated config first + setup_isolated_docker_config + + local registry="$ECR_REGISTRY" + + # Check if already authenticated (within this job) + if check_docker_auth "$registry"; then + echo "[docker-login] Already authenticated to $registry in this job" + return 0 + fi + + # Perform login to isolated config directory + echo "[docker-login] Logging in to $ECR_REGISTRY (isolated config)..." + if aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$ECR_REGISTRY"; then + echo "[docker-login] Login successful (config: $DOCKER_CONFIG)" + return 0 + else + local exit_code=$? + echo "[docker-login] ERROR: Login failed with exit code $exit_code" >&2 + return $exit_code + fi +} + +# Execute if run as script (not sourced) +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + safe_docker_login_ecr_public +fi diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh new file mode 100644 index 0000000000000000000000000000000000000000..f86b4b5d95808861d9ec0c53755efb65aac0f1c8 --- /dev/null +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -0,0 +1,166 @@ +#!/bin/bash +# vllm-omni customized version +# Based on: vllm/.buildkite/scripts/hardware_ci/run-amd-test.sh +# Last synced: 2025-12-15 +# Modifications: docker image name for vllm-omni + +# This script runs test inside the corresponding ROCm docker container. +set -o pipefail + +# Export Python path +export PYTHONPATH=".." + +# Print ROCm version +echo "--- Confirming Clean Initial State" +while true; do + sleep 3 + if grep -q clean /opt/amdgpu/etc/gpu_state; then + echo "GPUs state is \"clean\"" + break + fi +done + +echo "--- ROCm info" +rocminfo + +# cleanup older docker images +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes / force the system prune for old images as well. 
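+    # "--filter until=72h" limits the prune to containers/images created more than 72 hours ago,
+    # so images pulled by recent CI jobs survive the cleanup.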
+ docker volume prune -f && docker system prune --force --filter "until=72h" --all + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} + +# Call the cleanup docker function +cleanup_docker + +echo "--- Resetting GPUs" + +echo "reset" > /opt/amdgpu/etc/gpu_state + +while true; do + sleep 3 + if grep -q clean /opt/amdgpu/etc/gpu_state; then + echo "GPUs state is \"clean\"" + break + fi +done + +echo "--- Pulling container" +image_name="public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${BUILDKITE_COMMIT}-rocm-omni" +container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" + +# Install AWS CLI to authenticate to ECR Public Gallery to get higher rate limit for pulling images +sudo apt-get update && sudo apt-get install -y awscli +# Use safe docker login helper to prevent race conditions +source "$(dirname "${BASH_SOURCE[0]}")/../docker_login_ecr_public.sh" +safe_docker_login_ecr_public +# Pull the container from ECR Public Gallery + +docker pull "${image_name}" + +remove_docker_container() { + docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true +} +trap remove_docker_container EXIT + +echo "--- Running container" + +HF_CACHE="$(realpath ~)/huggingface" +mkdir -p "${HF_CACHE}" +HF_MOUNT="/root/.cache/huggingface" + +commands=$@ +echo "Commands:$commands" + +PARALLEL_JOB_COUNT=8 +MYPYTHONPATH=".." + +# Test that we're launching on the machine that has +# proper access to GPUs +render_gid=$(getent group render | cut -d: -f3) +if [[ -z "$render_gid" ]]; then + echo "Error: 'render' group not found. This is required for GPU access." >&2 + exit 1 +fi + +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +if [[ $commands == *"--shard-id="* ]]; then + # assign job count as the number of shards used + commands=$(echo "$commands" | sed -E "s/--num-shards[[:blank:]]*=[[:blank:]]*[0-9]*/--num-shards=${PARALLEL_JOB_COUNT} /g" | sed 's/ \\ / /g') + for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do + # assign shard-id for each shard + commands_gpu=$(echo "$commands" | sed -E "s/--shard-id[[:blank:]]*=[[:blank:]]*[0-9]*/--shard-id=${GPU} /g" | sed 's/ \\ / /g') + echo "Shard ${GPU} commands:$commands_gpu" + echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES" + docker run \ + --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ + --network=host \ + --shm-size=16gb \ + --group-add "$render_gid" \ + --rm \ + -e MIOPEN_DEBUG_CONV_DIRECT=0 \ + -e MIOPEN_DEBUG_CONV_GEMM=0 \ + -e VLLM_ROCM_USE_AITER=1 \ + -e HIP_VISIBLE_DEVICES="${GPU}" \ + -e HF_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ + --name "${container_name}_${GPU}" \ + "${image_name}" \ + /bin/bash -c "${commands_gpu}" \ + |& while read -r line; do echo ">>Shard $GPU: $line"; done & + PIDS+=($!) + done + #wait for all processes to finish and collect exit codes + for pid in "${PIDS[@]}"; do + wait "${pid}" + STATUS+=($?) 
+ done + for st in "${STATUS[@]}"; do + if [[ ${st} -ne 0 ]]; then + echo "One of the processes failed with $st" + exit "${st}" + fi + done +else + echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES" + docker run \ + --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ + --network=host \ + --shm-size=16gb \ + --group-add "$render_gid" \ + --rm \ + -e MIOPEN_DEBUG_CONV_DIRECT=0 \ + -e MIOPEN_DEBUG_CONV_GEMM=0 \ + -e VLLM_ROCM_USE_AITER=1 \ + -e HF_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -v "${HF_CACHE}:${HF_MOUNT}" \ + -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ + --name "${container_name}" \ + "${image_name}" \ + /bin/bash -c "${commands}" +fi diff --git a/.buildkite/scripts/hardware_ci/run_npu_test.sh b/.buildkite/scripts/hardware_ci/run_npu_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..fbe24badbb5d80917182cc46d23c60ea242d9bb1 --- /dev/null +++ b/.buildkite/scripts/hardware_ci/run_npu_test.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# This script build the Ascend NPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Base ubuntu image with basic ascend development libraries and python installed +VLLM_OMNI_REPO="https://github.com/vllm-project/vllm-omni.git" +BASE_IMAGE_NAME="quay.nju.edu.cn/ascend/vllm-ascend:v0.11.0rc2" +image_name="npu/vllm-omni-ci:${BUILDKITE_COMMIT}_${EPOCHSECONDS}" +# image_name="npu/vllm-ci:${BUILDKITE_COMMIT}_${EPOCHSECONDS}" +container_name="npu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" + +# BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards +agent_idx=$(echo "${BUILDKITE_AGENT_NAME}" | awk -F'-' '{print $(NF-1)}') +echo "agent_idx: ${agent_idx}" +builder_name="cachebuilder${agent_idx}" +builder_cache_dir="/mnt/docker-cache${agent_idx}" +mkdir -p ${builder_cache_dir} + +# Try building the docker image +cat <=6.0 pytest-cov modelscope + +COPY . . + +# Install vllm-omni +WORKDIR /workspace +ARG VLLM_OMNI_REPO=https://github.com/vllm-project/vllm-omni.git +ARG VLLM_OMNI_TAG=main +ARG BUILDKITE_PULL_REQUEST +ARG BUILDKITE_PULL_REQUEST_REPO +RUN git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf "https://github.com/" && \ + if [ "\$BUILDKITE_PULL_REQUEST" != "false" ] && [ -n "\$BUILDKITE_PULL_REQUEST" ]; then \ + echo "Cloning and checking out PR #\$BUILDKITE_PULL_REQUEST..." 
&& \ + git clone \$VLLM_OMNI_REPO /workspace/vllm-omni && \ + cd /workspace/vllm-omni && \ + git fetch origin pull/\$BUILDKITE_PULL_REQUEST/head:pr-\$BUILDKITE_PULL_REQUEST && \ + git checkout pr-\$BUILDKITE_PULL_REQUEST; \ + else \ + echo "Not a PR build, using main branch" && \ + git clone --depth 1 \$VLLM_OMNI_REPO /workspace/vllm-omni; \ + fi + +RUN --mount=type=cache,target=/root/.cache/pip \ + export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + python3 -m pip install -v -e /workspace/vllm-omni/ + +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn + +WORKDIR /workspace/vllm-omni +CMD ["/bin/bash"] + +EOF + +# Setup cleanup +remove_docker_container() { + docker rm -f "${container_name}" || true; + docker image rm -f "${image_name}" || true; + docker system prune -f || true; +} +trap remove_docker_container EXIT + +# Generate corresponding --device args based on BUILDKITE_AGENT_NAME +# Ascend NPU BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards, and agent_idx starts from 1. +# e.g. atlas-a2-001-1-2cards means this is the 1-th agent on atlas-a2-001 host, and it has 2 NPU cards. +# returns --device /dev/davinci0 --device /dev/davinci1 +parse_and_gen_devices() { + local input="$1" + local index cards_num + if [[ "$input" =~ ([0-9]+)-([0-9]+)cards$ ]]; then + index="${BASH_REMATCH[1]}" + cards_num="${BASH_REMATCH[2]}" + else + echo "parse error" >&2 + return 1 + fi + + local devices="" + local i=0 + while (( i < cards_num )); do + local dev_idx=$(((index - 1)*cards_num + i )) + devices="$devices --device /dev/davinci${dev_idx}" + ((i++)) + done + + # trim leading space + devices="${devices#"${devices%%[![:space:]]*}"}" + # Output devices: assigned to the caller variable + printf '%s' "$devices" +} + +devices=$(parse_and_gen_devices "${BUILDKITE_AGENT_NAME}") || exit 1 + +# Run the image and execute the Out-Of-Tree (OOT) platform interface test case on Ascend NPU hardware. +# This test checks whether the OOT platform interface is functioning properly in conjunction with +# the hardware plugin vllm-ascend. 
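+# Per-agent host cache directories (suffixed with the agent index) are mounted into the container
+# below, so repeated CI runs reuse previously downloaded HuggingFace/ModelScope models.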
+hf_model_cache_dir=/mnt/hf_cache${agent_idx} +ms_model_cache_dir=/mnt/modelscope${agent_idx} +mkdir -p ${hf_model_cache_dir} +mkdir -p ${ms_model_cache_dir} +docker run \ + --init \ + ${devices} \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v ${hf_model_cache_dir}:/root/.cache/huggingface \ + -v ${ms_model_cache_dir}:/root/.cache/modelscope \ + --network host \ + --entrypoint="" \ + --name "${container_name}" \ + "${image_name}" \ + bash -c ' + set -e + VLLM_USE_MODELSCOPE=True pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py +' diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fac5c7268bf0a1acf718af7f98696f09e4a6d2f6 --- /dev/null +++ b/.buildkite/test-amd.yaml @@ -0,0 +1,116 @@ +steps: + +- label: "Diffusion Model Test" + timeout_in_minutes: 20 + agent_pool: mi325_2 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - pytest -s -v tests/e2e/offline_inference/test_t2i_model.py + +- label: "Diffusion Images API LoRA E2E" + timeout_in_minutes: 20 + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/online_serving/test_images_generations_lora.py + +- label: "Diffusion Model CPU offloading Test" + timeout_in_minutes: 20 + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py + +- label: "Diffusion Cache Backend Test" + timeout_in_minutes: 15 + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_cache_dit.py tests/e2e/offline_inference/test_teacache.py + +- label: "Diffusion Sequence Parallelism Test" + timeout_in_minutes: 20 + agent_pool: mi325_2 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py + +- label: "Diffusion Tensor Parallelism Test" + timeout_in_minutes: 20 + agent_pool: mi325_2 + depends_on: amd-build + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_zimage_tensor_parallel.py + +- label: "Diffusion GPU Worker Test" + timeout_in_minutes: 20 + agent_pool: mi325_2 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - pytest -s -v tests/diffusion/test_diffusion_worker.py + +- label: "Omni Model Test Qwen2-5-Omni" + timeout_in_minutes: 15 + agent_pool: mi325_2 
+ depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py + +- label: "Omni Model Test Qwen3-Omni" + timeout_in_minutes: 15 + agent_pool: mi325_2 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - export VLLM_TEST_CLEAN_GPU_MEMORY="1" + - pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py + - pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py + - pytest -s -v tests/e2e/online_serving/test_async_omni.py + +- label: "Diffusion Image Edit Test" + timeout_in_minutes: 15 + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2 new file mode 100644 index 0000000000000000000000000000000000000000..6442ff5441fb8393012f20e8b4bcaa36cad67794 --- /dev/null +++ b/.buildkite/test-template-amd-omni.j2 @@ -0,0 +1,53 @@ +{# vllm-omni customized version + Based on: https://github.com/vllm-project/ci-infra/blob/main/buildkite/test-template-amd.j2 + Last synced: 2025-12-15 + Modifications: Removed unused CUDA/NVIDIA logic, keeping only AMD tests +#} +{% set docker_image_amd = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT-rocm-omni" %} +{% set default_working_dir = "/app/vllm-omni" %} + + - group: "AMD Tests" + depends_on: ~ + steps: + - label: "AMD: :docker: build image" + depends_on: ~ + soft_fail: false + commands: + - "source .buildkite/scripts/docker_login_ecr_public.sh && safe_docker_login_ecr_public" + - "docker build -f docker/Dockerfile.rocm -t {{ docker_image_amd }} --target final --progress plain ." 
+ - "docker push {{ docker_image_amd }}" + key: "amd-build" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 1 + - exit_status: -10 # Agent was lost + limit: 1 + - exit_status: 1 # Machine occasionally fail + limit: 1 + agents: + queue: cpu_queue_premerge + + {% for step in steps %} + {% if step.mirror_hardwares and mirror_hw in step.mirror_hardwares %} + - label: "{{ step.agent_pool }}: {{ step.label }}" + depends_on: amd-build + agents: + {% if step.agent_pool %} + queue: amd_{{ step.agent_pool }} + {% else %} + queue: amd_mi325_1 + {% endif %} + command: bash .buildkite/scripts/hardware_ci/run-amd-test.sh "(command rocm-smi || true) && cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}" + env: + DOCKER_BUILDKIT: "1" + priority: 100 + {% if step.grade and step.grade == "Blocking" %} + soft_fail: false + {% else %} + soft_fail: true + {% endif%} + {% endif %} + {% endfor %} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e13098f6486a3016d2afdabc0419e4b9d2a7eda --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,70 @@ +default_install_hook_types: + - pre-commit + - commit-msg +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: check-yaml + args: ["--unsafe"] + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + args: ["--fix=lf"] + - id: trailing-whitespace + args: ["--markdown-linebreak-ext=md"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.10 + hooks: + - id: ruff-check + args: [--output-format, github, --fix] + - id: ruff-format + + - repo: https://github.com/crate-ci/typos + rev: typos-dict-v0.13.13 + hooks: + - id: typos + # only for staged files + + - repo: https://github.com/rhysd/actionlint + # v1.7.8+ sets `go 1.24.0` in go.mod, which older Go toolchains (and most + # current CI images) cannot parse. Pin to v1.7.7 until actionlint fixes the + # go.mod directive. + rev: v1.7.7 + hooks: + - id: actionlint + files: ^\.github/workflows/.*\.ya?ml$ + + + - repo: local + hooks: + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" + fi + language: system + verbose: true + stages: [commit-msg] + + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. 
To skip a specific hook, prefix the commit command with SKIP=."' + language: system + verbose: true + pass_filenames: false + # Insert new entries above the `suggestion` entry + + - id: check-pickle-imports + name: Prevent new pickle/cloudpickle imports + entry: python tools/pre_commit/check_pickle_imports.py + language: python + types: [python] + additional_dependencies: [regex] diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..5a06c663409d402e906ef6be56cf9f8921d538c1 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,24 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.12" + jobs: + post_checkout: + - git fetch --unshallow || true + +mkdocs: + configuration: mkdocs.yml + fail_on_warning: true + +# Optionally declare the Python requirements required to build your docs +python: + install: + - method: pip + path: . + extra_requirements: + - docs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..aac7497757c753f5eac1b043083ee3ede182077c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,3 @@ +# Contributing to vLLM-Omni + +You may find information about contributing to vLLM-Omni on [Contributing](https://vllm-omni.readthedocs.io/en/latest/contributing/) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 4fe8f0b2fbb83734f05bdb0a82b22313b096c0b7..a61a45032b0144e76e2dbf60c331a2c9e311f0be 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,93 @@ -# vllm-omni +

+ + + vllm-omni + +

+

+Easy, fast, and cheap omni-modality model serving for everyone +

-vLLM 最初是为支持文本生成任务的大型语言模型而设计的。vLLM-Omni 是一个框架,它将 vLLM 的支持扩展到全模态模型推理和服务的领域。 \ No newline at end of file +

+| Documentation | User Forum | Developer Slack | WeChat | +

+ +--- + +*Latest News* 🔥 + +- [2026/02] We released [0.14.0](https://github.com/vllm-project/vllm-omni/releases/tag/v0.14.0) - This is the first **stable release** of vLLM-Omni that expands Omni’s diffusion / image-video generation and audio / TTS stack, improves distributed execution and memory efficiency, and broadens platform/backend coverage (GPU/ROCm/NPU/XPU). It also brings meaningful upgrades to serving APIs, profiling & benchmarking, and overall stability. Please check our latest [paper](https://arxiv.org/abs/2602.02204) for architecture design and performance results. +- [2026/01] We released [0.12.0rc1](https://github.com/vllm-project/vllm-omni/releases/tag/v0.12.0rc1) - a major RC milestone focused on maturing the diffusion stack, strengthening OpenAI-compatible serving, expanding omni-model coverage, and improving stability across platforms (GPU/NPU/ROCm), please check our latest [design](https://docs.google.com/presentation/d/1qv4qMW1rKAqDREMXiUDLIgqqHQe7TDPj/edit?usp=sharing&ouid=110473603432222024453&rtpof=true&sd=true). +- [2025/11] vLLM community officially released [vllm-project/vllm-omni](https://github.com/vllm-project/vllm-omni) in order to support omni-modality models serving. + +--- + +## About + +[vLLM](https://github.com/vllm-project/vllm) was originally designed to support large language models for text-based autoregressive generation tasks. vLLM-Omni is a framework that extends its support for omni-modality model inference and serving: + +- **Omni-modality**: Text, image, video, and audio data processing +- **Non-autoregressive Architectures**: extend the AR support of vLLM to Diffusion Transformers (DiT) and other parallel generation models +- **Heterogeneous outputs**: from traditional text generation to multimodal outputs + +

+ + vllm-omni + +

+ +vLLM-Omni is fast with: + +- State-of-the-art AR support by leveraging efficient KV cache management from vLLM +- Pipelined stage execution overlapping for high throughput performance +- Fully disaggregation based on OmniConnector and dynamic resource allocation across stages + +vLLM-Omni is flexible and easy to use with: + +- Heterogeneous pipeline abstraction to manage complex model workflows +- Seamless integration with popular Hugging Face models +- Tensor, pipeline, data and expert parallelism support for distributed inference +- Streaming outputs +- OpenAI-compatible API server + +vLLM-Omni seamlessly supports most popular open-source models on HuggingFace, including: + +- Omni-modality models (e.g. Qwen-Omni) +- Multi-modality generation models (e.g. Qwen-Image) + +## Getting Started + +Visit our [documentation](https://vllm-omni.readthedocs.io/en/latest/) to learn more. + +- [Installation](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) +- [Quickstart](https://vllm-omni.readthedocs.io/en/latest/getting_started/quickstart/) +- [List of Supported Models](https://vllm-omni.readthedocs.io/en/latest/models/supported_models/) + +## Contributing + +We welcome and value any contributions and collaborations. +Please check out [Contributing to vLLM-Omni](https://vllm-omni.readthedocs.io/en/latest/contributing/) for how to get involved. + +## Citation + +If you use vLLM-Omni for your research, please cite our [paper](https://arxiv.org/abs/2602.02204): + +```bibtex +@article{yin2026vllmomni, + title={vLLM-Omni: Fully Disaggregated Serving for Any-to-Any Multimodal Models}, + author={Peiqi Yin, Jiangyun Zhu, Han Gao, Chenguang Zheng, Yongxiang Huang, Taichang Zhou, Ruirui Yang, Weizhi Liu, Weiqing Chen, Canlin Guo, Didan Deng, Zifeng Mo, Cong Wang, James Cheng, Roger Wang, Hongsheng Liu}, + journal={arXiv preprint arXiv:2602.02204}, + year={2026} +} +``` + +## Join the Community +Feel free to ask questions, provide feedbacks and discuss with fellow users of vLLM-Omni in `#sig-omni` slack channel at [slack.vllm.ai](https://slack.vllm.ai) or vLLM user forum at [discuss.vllm.ai](https://discuss.vllm.ai). + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=vllm-project/vllm-omni&type=date&legend=top-left)](https://www.star-history.com/#vllm-project/vllm-omni&type=date&legend=top-left) + +## License + +Apache License 2.0, as found in the [LICENSE](./LICENSE) file. diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5ff270b7fc75b82f41e71017cb982c08ffdc5d5b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,42 @@ +# Benchmarks Overview and Architecture + +This document explains the benchmark architecture across all benchmark assets in this repo. It describes what we measure, and where to find or plug in new scenarios. Per-task details remain in subfolder READMEs (e.g., `benchmarks//README.md`). + +## Scope and goals +- Establish repeatable latency/throughput measurements for multimodal LLM pipelines. +- Provide both HF Transformers (offline) and vLLM-Omni (multi-stage/pipeline) baselines. +- Make it easy to plug in new datasets and models with minimal changes to the runner scripts. + +## Dataset and inputs +- Default example: SeedTTS top-100 prompts (`benchmarks/build_dataset/top100.txt`) via `benchmarks/build_dataset/`. +- Extensible: drop in new prompt files or modality-aligned payloads; keep the expected format for the consuming scripts (e.g., one prompt per line). 
+- If you add a new dataset, document it under `benchmarks//README.md` and point scripts to your data path. + +## Directory layout +- `benchmarks/build_dataset/` — dataset prep utilities (e.g., SeedTTS top100). +- `benchmarks//vllm_omni/` — vLLM-Omni pipeline benchmarks, logs, outputs. +- Add new tasks under `benchmarks//...` with the same pattern: `transformers/`, `vllm_omni/`, task-specific README, and (optionally) dataset prep notes. + +## Reference workflows +- **HF Transformers (offline, single process)** + Script (example): `benchmarks//transformers/eval_qwen3_moe_omni_transformers.sh` + Outputs: `benchmark_results/perf_stats.json`, `benchmark_results/results.json`, `benchmark_results/audio/` (if audio is produced). + +- **vLLM-Omni end-to-end pipeline** + Script (example): `benchmarks//vllm_omni/eval_qwen3_moe_omni.sh` + Outputs: `vllm_omni/logs/*.stats.jsonl` (per-stage/overall latency & TPS), `vllm_omni/logs/stage*.log`, `vllm_omni/outputs/` (text/audio artifacts). + +- **Adding a new task/model** + 1) Create `benchmarks//transformers/` and/or `benchmarks//vllm_omni/` with scripts referencing your model and dataset. + 2) Add a task README describing dataset, configs, and expected outputs. + 3) Keep the output/log structure similar for easy comparison (perf_stats/results/audio or text outputs; stats.jsonl/logs for pipeline). + +## Metrics to watch +- **Throughput**: `overall_tps`, `*_tps_avg` per stage. +- **Latency distribution**: look for long tails in `*.stats.jsonl`. +- **Quality/completeness**: missing outputs or errors in stage logs indicate pipeline failures or misconfigurations. + +## Troubleshooting +- Verify GPU/driver/FlashAttention2 requirements for your chosen model/config. +- Ensure network access for dataset/model downloads (Google Drive, Hugging Face, etc.). +- If outputs are missing or slow, inspect per-stage logs and `*.stats.jsonl` for errors, stragglers, or contention. diff --git a/benchmarks/build_dataset/download_process_data_seedtts.md b/benchmarks/build_dataset/download_process_data_seedtts.md new file mode 100644 index 0000000000000000000000000000000000000000..54bb0f9c9a49447afd8aaf14a9ea4831c2089020 --- /dev/null +++ b/benchmarks/build_dataset/download_process_data_seedtts.md @@ -0,0 +1,82 @@ +# Benchmark Dataset Preparation Guide + +This guide describes how to download and prepare the SeedTTS test dataset for benchmarking Qwen-Omni models. + +## Prerequisites + +- Python 3.8+ +- `gdown` for downloading from Google Drive +- Access to the benchmark scripts + +## Steps + +### 1. Navigate to the Dataset Directory + +```bash +cd benchmarks/build_dataset +``` + +### 2. Install Dependencies + +```bash +pip install gdown +``` + +### 3. Download the SeedTTS Test Dataset + +Download the dataset from Google Drive: + +```bash +gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP +``` + +### 4. Extract the Dataset + +```bash +tar -xf seedtts_testset.tar +``` + +### 5. Prepare the Metadata File + +Copy the English metadata file to the working directory: + +```bash +cp seedtts_testset/en/meta.lst meta.lst +``` + +### 6. Extract Prompts + +Extract the first N prompts from the metadata file: + +```bash +# Extract top 100 prompts (adjust -n for different amounts) +python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100 +``` + +**Options:** +- `-i, --input`: Input metadata file (default: `meta.lst`) +- `-o, --output`: Output prompts file (default: `prompts.txt`) +- `-n, --num_lines`: Number of prompts to extract (required) + +### 7. 
Clean Up (Optional) + +Remove temporary files to save disk space: + +```bash +rm -rf seedtts_testset +rm seedtts_testset.tar +rm meta.lst +``` + +## Quick Start (All-in-One) + +```bash +# Full setup and benchmark +cd benchmarks/build_dataset +pip install gdown +gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP +tar -xf seedtts_testset.tar +cp seedtts_testset/en/meta.lst meta.lst +python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100 +rm -rf seedtts_testset seedtts_testset.tar meta.lst +``` diff --git a/benchmarks/build_dataset/extract_tts_prompts.py b/benchmarks/build_dataset/extract_tts_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca3d190821295d03dbf1a6315a214b410d4d375 --- /dev/null +++ b/benchmarks/build_dataset/extract_tts_prompts.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +Extract prompts from meta.lst and save them to a txt file. + +Each line in meta.lst has the format: +ID|prompt_text|audio_path|target_text + +This script extracts the prompt_text (second field) from the first N lines. +""" + +import argparse +from pathlib import Path + + +def extract_prompts(input_file: str, output_file: str, num_lines: int) -> None: + """ + Extract prompts from meta.lst and save to output file. + + Args: + input_file: Path to the meta.lst file + output_file: Path to the output txt file + num_lines: Number of lines to process + """ + prompts = [] + + with open(input_file, encoding="utf-8") as f: + for i, line in enumerate(f): + if i >= num_lines: + break + + line = line.strip() + if not line: # Skip empty lines + continue + + parts = line.split("|") + if len(parts) >= 2: + prompt = parts[1] # The prompt is the second field + prompts.append(prompt) + + # Write prompts to output file + with open(output_file, "w", encoding="utf-8") as f: + for prompt in prompts: + f.write(prompt + "\n") + + # Print result stats + print(f"Extracted {len(prompts)} prompts from first {num_lines} lines") + print(f"Saved to: {output_file}") + + +def main(): + parser = argparse.ArgumentParser(description="Extract prompts from meta.lst file") + parser.add_argument( + "-i", "--input", type=str, default="meta.lst", help="Input meta.lst file path (default: meta.lst)" + ) + parser.add_argument( + "-o", "--output", type=str, default="prompts.txt", help="Output txt file path (default: prompts.txt)" + ) + parser.add_argument( + "-n", "--num_lines", type=int, required=True, help="Number of lines to extract from the beginning" + ) + + args = parser.parse_args() + + # Check if input file exists + if not Path(args.input).exists(): + print(f"Error: Input file '{args.input}' not found") + return + + extract_prompts(args.input, args.output, args.num_lines) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/diffusion/README.md b/benchmarks/diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3eb775fe23fb1733775f88d5d356d58a10fda126 --- /dev/null +++ b/benchmarks/diffusion/README.md @@ -0,0 +1,117 @@ + +# Diffusion Serving Benchmark (Image/Video) + +This folder contains an online-serving benchmark script for diffusion models. +It sends requests to a vLLM OpenAI-compatible endpoint and reports throughput, +latency percentiles, and optional SLO attainment. + +The main entrypoint is: + +- `benchmarks/diffusion/diffusion_benchmark_serving.py` + +## 1. Quick Start + +1. Start the server: + +```bash +vllm serve Qwen/Qwen-Image --omni --port 8099 +``` + +2. 
Run a minimal benchmark:
+
+```bash
+python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
+    --base-url http://localhost:8099 \
+    --model Qwen/Qwen-Image \
+    --task t2i \
+    --dataset vbench \
+    --num-prompts 5
+```
+
+**Notes**
+
+- The benchmark talks to `http://<host>:<port>/v1/chat/completions`.
+- If you run the server on another host or port, pass `--base-url` accordingly.
+
+## 2. Supported Datasets
+
+The benchmark supports three dataset modes via `--dataset`:
+
+- `vbench`: Built-in prompt/data loader.
+- `trace`: Heterogeneous request traces (each request can have different resolution/frames/steps).
+- `random`: Synthetic prompts for quick smoke tests.
+
+### VBench dataset
+
+If you use the i2v/i2i bench datasets and need auto-download support, you may need to install `gdown` first:
+
+```bash
+uv pip install gdown
+```
+
+### Trace dataset
+
+Use `--dataset trace` to replay a trace file. Each request in a trace file is a `Request(...)` repr line with keyword fields, e.g. `Request(prompt='a cat', width=1024, height=1024, num_inference_steps=30)`. The trace can specify per-request fields such as:
+
+- `width`, `height`
+- `num_frames` (video)
+- `num_inference_steps`
+- `seed`, `fps`
+- optional `slo_ms` (per-request SLO target)
+
+By default (when `--dataset-path` is not provided), the script downloads a default trace from
+the HuggingFace dataset repo `asukaqaqzz/Dit_Trace`. The default filename can depend on `--task`
+(e.g., `t2v` uses a video trace).
+
+Current defaults:
+
+- `--task t2i` -> `sd3_trace.txt`
+- `--task t2v` -> `cogvideox_trace.txt`
+
+You can point to your own trace using `--dataset-path`.
+
+## 3. Benchmark Parameters
+
+### Basic flags
+
+- `--base-url`: Server address (the script calls `.../v1/chat/completions`).
+- `--model`: The OpenAI-compatible `model` field.
+- `--task`: Task type (e.g., `t2i`, `t2v`, `i2i`, `i2v`).
+- `--dataset`: Dataset mode (`vbench` / `trace` / `random`).
+- `--num-prompts`: Number of requests to send.
+
+Common optional flags:
+
+- `--output-file`: Write metrics to a JSON file.
+- `--disable-tqdm`: Disable the progress bar.
+
+### Resolution / frames / steps: CLI defaults vs dataset fields
+
+Related flags: `--width`, `--height`, `--num-frames`, `--fps`, `--num-inference-steps`.
+
+- For `vbench` / `random`: these CLI flags act as global defaults for all generated requests.
+- For `trace`: each request can carry its own fields (e.g., `width/height/num_frames/num_inference_steps`).
+
+Precedence rules for `trace` (i.e., what actually gets sent):
+
+- `width/height`: if either `--width` or `--height` is explicitly set, it overrides per-request values from the trace; otherwise per-request values are used when present.
+- `num_frames`: per-request `num_frames` takes precedence; otherwise fall back to `--num-frames`.
+- `num_inference_steps`: per-request `num_inference_steps` takes precedence; otherwise fall back to `--num-inference-steps`.
+
+### SLO, warmup, and max concurrency
+
+Enable SLO evaluation with `--slo`.
+
+- If a request in the trace already has `slo_ms`, that value is used.
+- Otherwise, the script runs warmup requests to infer a base unit time, estimates `expected_ms` by linearly scaling with area/frames/steps, and then sets `slo_ms = expected_ms * --slo-scale`. For example, with the default `--slo-scale` of 3, a request whose estimated execution time is 10 s gets `slo_ms = 30000`.
+
+Warmup flags:
+
+- `--warmup-requests`: Number of warmup requests.
+- `--warmup-num-inference-steps`: Steps used during warmup.
+- For `--task t2v`: warmup requests are forced to use `num_frames=1` to make warmup faster and less noisy.
+
+Traffic / concurrency flags:
+
+- `--request-rate`: Target request rate (requests/second). If set to `inf`, the script sends all requests immediately.
+- `--max-concurrency`: Max number of in-flight requests (default: `1`). This can hard-cap the achieved QPS: if it is too small, requests will queue behind the semaphore, and both achieved throughput and observed SLO attainment can be skewed. diff --git a/benchmarks/diffusion/diffusion_benchmark_serving.py b/benchmarks/diffusion/diffusion_benchmark_serving.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb4cdcffa631c26627e9f1f92749be60d089857 --- /dev/null +++ b/benchmarks/diffusion/diffusion_benchmark_serving.py @@ -0,0 +1,1070 @@ +# adapted from sglang and fastvideo +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark online serving for diffusion models (Image/Video Generation). +If you want to use i2v, i2i dataset, you should `uv pip install gdown` first + +Usage: + # Video + t2v: + python3 benchmarks/diffusion/diffusion_benchmark_serving.py \ + --dataset vbench --task t2v --num-prompts 10 \ + --height 480 --width 640 --fps 16 --num-frames 80 + + i2v: + python3 benchmarks/diffusion/diffusion_benchmark_serving.py \ + --dataset vbench --task i2v --num-prompts 10 + + + # Image + t2i: + python3 benchmarks/diffusion/diffusion_benchmark_serving.py \ + --dataset vbench --task t2i --num-prompts 10 \ + --height 1024 --width 1024 + + i2i: + python3 benchmarks/diffusion/diffusion_benchmark_serving.py \ + --dataset vbench --task i2i --num-prompts 10 + +""" + +import argparse +import ast +import asyncio +import base64 +import glob +import json +import mimetypes +import os +import time +import uuid +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field, replace +from typing import Any + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + model: str + width: int | None = None + height: int | None = None + num_frames: int | None = None + num_inference_steps: int | None = None + seed: int | None = None + fps: int | None = None + timestamp: float | None = None + slo_ms: float | None = None + extra_body: dict[str, Any] = field(default_factory=dict) + image_paths: list[str] | None = None + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RequestFuncOutput: + success: bool = False + latency: float = 0.0 + error: str = "" + start_time: float = 0.0 + response_body: dict[str, Any] = field(default_factory=dict) + peak_memory_mb: float = 0.0 + slo_achieved: bool | None = None + + +class BaseDataset(ABC): + def __init__(self, args, api_url: str, model: str): + self.args = args + self.api_url = api_url + self.model = model + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __getitem__(self, idx: int) -> RequestFuncInput: + pass + + @abstractmethod + def get_requests(self) -> list[RequestFuncInput]: + pass + + +class VBenchDataset(BaseDataset): + """ + Dataset loader for VBench prompts. + Supports t2v, i2v. 
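+    Tasks i2i/ti2i/ti2v reuse the i2v loader; any other task falls back to the t2v prompt list.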
+ """ + + T2V_PROMPT_URL = ( + "https://raw.githubusercontent.com/Vchitect/VBench/master/prompts/prompts_per_dimension/subject_consistency.txt" + ) + I2V_DOWNLOAD_SCRIPT_URL = ( + "https://raw.githubusercontent.com/Vchitect/VBench/master/vbench2_beta_i2v/download_data.sh" + ) + + def __init__(self, args, api_url: str, model: str): + super().__init__(args, api_url, model) + self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "vllm-omni") + self.items = self._load_data() + + def _load_data(self) -> list[dict[str, Any]]: + if self.args.task == "t2v": + return self._load_t2v_prompts() + elif self.args.task in ["i2v", "ti2v", "ti2i", "i2i"]: + return self._load_i2v_data() + else: + return self._load_t2v_prompts() + + def _download_file(self, url: str, dest_path: str) -> None: + """Download a file from URL to destination path.""" + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + resp = requests.get(url) + resp.raise_for_status() + with open(dest_path, "w") as f: + f.write(resp.text) + + def _load_t2v_prompts(self) -> list[dict[str, Any]]: + path = self.args.dataset_path + + if not path: + path = os.path.join(self.cache_dir, "vbench_subject_consistency.txt") + if not os.path.exists(path): + print(f"Downloading VBench T2V prompts to {path}...") + try: + self._download_file(self.T2V_PROMPT_URL, path) + except Exception as e: + print(f"Failed to download VBench prompts: {e}") + return [{"prompt": "A cat sitting on a bench"}] * 50 + + prompts = [] + with open(path) as f: + for line in f: + line = line.strip() + if line: + prompts.append({"prompt": line}) + + return self._resize_data(prompts) + + def _auto_download_i2v_dataset(self) -> str: + """Auto-download VBench I2V dataset and return the dataset directory.""" + vbench_i2v_dir = os.path.join(self.cache_dir, "vbench_i2v", "vbench2_beta_i2v") + info_json_path = os.path.join(vbench_i2v_dir, "data", "i2v-bench-info.json") + + if os.path.exists(info_json_path): + return vbench_i2v_dir + + print(f"Downloading VBench I2V dataset to {vbench_i2v_dir}...") + try: + cache_root = os.path.join(self.cache_dir, "vbench_i2v") + script_path = os.path.join(cache_root, "download_data.sh") + + self._download_file(self.I2V_DOWNLOAD_SCRIPT_URL, script_path) + os.chmod(script_path, 0o755) + + print("Executing download_data.sh (this may take a while)...") + import subprocess + + result = subprocess.run( + ["bash", script_path], + cwd=cache_root, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Download script failed: {result.stderr}") + + print(f"Successfully downloaded VBench I2V dataset to {vbench_i2v_dir}") + except Exception as e: + print(f"Failed to download VBench I2V dataset: {e}") + print("Please manually download following instructions at:") + print("https://github.com/Vchitect/VBench/tree/master/vbench2_beta_i2v#22-download") + return None + + return vbench_i2v_dir if os.path.exists(info_json_path) else None + + def _load_from_i2v_json(self, json_path: str) -> list[dict[str, Any]]: + """Load I2V data from i2v-bench-info.json format.""" + with open(json_path) as f: + items = json.load(f) + + base_dir = os.path.dirname(os.path.dirname(json_path)) # Go up to vbench2_beta_i2v + origin_dir = os.path.join(base_dir, "data", "origin") + + data = [] + for item in items: + img_path = os.path.join(origin_dir, item.get("file_name", "")) + if os.path.exists(img_path): + data.append({"prompt": item.get("caption", ""), "image_path": img_path}) + else: + print(f"Warning: Image not found: {img_path}") + + 
print(f"Loaded {len(data)} I2V samples from VBench I2V dataset") + return data + + def _scan_directory_for_images(self, path: str) -> list[dict[str, Any]]: + """Scan directory for image files.""" + exts = ["*.jpg", "*.jpeg", "*.png", "*.webp"] + files = [] + + for ext in exts: + files.extend(glob.glob(os.path.join(path, ext))) + files.extend(glob.glob(os.path.join(path, ext.upper()))) + + # Also check in data/origin subdirectory + origin_dir = os.path.join(path, "data", "origin") + if os.path.exists(origin_dir): + files.extend(glob.glob(os.path.join(origin_dir, ext))) + files.extend(glob.glob(os.path.join(origin_dir, ext.upper()))) + + return [{"prompt": os.path.splitext(os.path.basename(f))[0], "image_path": f} for f in files] + + def _create_dummy_data(self) -> list[dict[str, Any]]: + """Create dummy data with a placeholder image in cache directory.""" + print("No I2V data found. Using dummy placeholders.") + + dummy_image = os.path.join(self.cache_dir, "dummy_image.jpg") + if not os.path.exists(dummy_image): + try: + from PIL import Image + + os.makedirs(self.cache_dir, exist_ok=True) + img = Image.new("RGB", (100, 100), color="red") + img.save(dummy_image) + print(f"Created dummy image at {dummy_image}") + except ImportError: + print("PIL not installed, cannot create dummy image.") + return [] + + return [{"prompt": "A moving cat", "image_path": dummy_image}] * 10 + + def _load_i2v_data(self) -> list[dict[str, Any]]: + """Load I2V data from VBench I2V dataset or user-provided path.""" + path = self.args.dataset_path + + # Auto-download if no path provided + if not path: + path = self._auto_download_i2v_dataset() + if not path: + return self._resize_data(self._create_dummy_data()) + + # Try to load from i2v-bench-info.json + info_json_candidates = [ + os.path.join(path, "data", "i2v-bench-info.json"), + path if path.endswith(".json") else None, + ] + + for json_path in info_json_candidates: + if json_path and os.path.exists(json_path): + try: + return self._resize_data(self._load_from_i2v_json(json_path)) + except Exception as e: + print(f"Failed to load {json_path}: {e}") + + # Fallback: scan directory for images + if os.path.isdir(path): + data = self._scan_directory_for_images(path) + if data: + return self._resize_data(data) + + # Last resort: dummy data + return self._resize_data(self._create_dummy_data()) + + def _resize_data(self, data: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Resize data to match num_prompts.""" + if not self.args.num_prompts: + return data + + if len(data) < self.args.num_prompts: + factor = (self.args.num_prompts // len(data)) + 1 + data = data * factor + + return data[: self.args.num_prompts] + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> RequestFuncInput: + item = self.items[idx] + image_paths = [item["image_path"]] if "image_path" in item else None + + return RequestFuncInput( + prompt=item.get("prompt", ""), + api_url=self.api_url, + model=self.model, + width=self.args.width, + height=self.args.height, + num_frames=self.args.num_frames, + num_inference_steps=self.args.num_inference_steps, + seed=self.args.seed, + fps=self.args.fps, + image_paths=image_paths, + ) + + def get_requests(self) -> list[RequestFuncInput]: + return [self[i] for i in range(len(self))] + + +class TraceDataset(BaseDataset): + """Trace-based dataset loader for heterogeneous diffusion requests.""" + + DEFAULT_REPO_ID = "asukaqaqzz/Dit_Trace" + DEFAULT_FILENAME = "sd3_trace.txt" + DEFAULT_FILENAME_BY_TASK: dict[str, str] = { + # 
Text-to-image traces (e.g., SD3) + "t2i": "sd3_trace.txt", + # Text-to-video traces (e.g., CogVideoX) + "t2v": "cogvideox_trace.txt", + } + + def __init__(self, args, api_url: str, model: str): + super().__init__(args, api_url, model) + self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "vllm-omni", "trace") + self.default_filename = self.DEFAULT_FILENAME_BY_TASK.get(getattr(args, "task", ""), self.DEFAULT_FILENAME) + dataset_root = args.dataset_path + if not dataset_root: + dataset_root = self._download_default_trace() + self.items = self._load_items(dataset_root) + + @staticmethod + def _coerce_int(x: Any) -> int | None: + if x is None: + return None + if isinstance(x, bool): + return None + if isinstance(x, int): + return x + try: + s = str(x).strip() + if not s: + return None + return int(float(s)) + except Exception: + return None + + @staticmethod + def _coerce_float(x: Any) -> float | None: + if x is None: + return None + if isinstance(x, float): + return x + if isinstance(x, int): + return float(x) + try: + s = str(x).strip() + if not s: + return None + return float(s) + except Exception: + return None + + def _download_default_trace(self) -> str: + """Download default trace file from HuggingFace Hub if not provided.""" + + try: + from huggingface_hub import hf_hub_download + except ImportError as exc: + raise ImportError( + "huggingface_hub is required to download the default trace dataset. " + "Install via `pip install huggingface_hub`." + ) from exc + + os.makedirs(self.cache_dir, exist_ok=True) + return hf_hub_download( + repo_id=self.DEFAULT_REPO_ID, + filename=self.default_filename, + repo_type="dataset", + local_dir=self.cache_dir, + local_dir_use_symlinks=False, + ) + + def _expand_paths(self, dataset_path: str | None) -> list[str]: + if not dataset_path: + return [] + + parts = [p.strip() for p in str(dataset_path).split(",") if p.strip()] + paths: list[str] = [] + for p in parts: + if any(ch in p for ch in ["*", "?", "["]): + paths.extend(sorted(glob.glob(p))) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "**", "*.txt"), recursive=True))) + else: + paths.append(p) + + seen = set() + unique_paths = [] + for p in paths: + if p not in seen: + seen.add(p) + unique_paths.append(p) + return unique_paths + + def _parse_trace_file(self, path: str) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + + def parse_request_repr_line(line: str) -> dict[str, Any] | None: + text = line.strip() + if not text: + return None + if not (text.startswith("Request(") and text.endswith(")")): + return None + inner = text[len("Request(") : -1] + try: + expr = ast.parse(f"f({inner})", mode="eval") + if not isinstance(expr.body, ast.Call): + return None + call = expr.body + out: dict[str, Any] = {} + for kw in call.keywords: + if kw.arg is None: + continue + out[kw.arg] = ast.literal_eval(kw.value) + return out + except Exception: + return None + + # detect first non-empty line to pick parser + first_non_empty = None + with open(path, encoding="utf-8") as f: + for _ in range(50): + pos = f.tell() + line = f.readline() + if not line: + break + if line.strip(): + first_non_empty = line.strip() + f.seek(pos) + break + + if first_non_empty is None: + return rows + + if first_non_empty.startswith("Request("): + with open(path, encoding="utf-8") as f: + for line in f: + parsed = parse_request_repr_line(line) + if isinstance(parsed, dict): + rows.append(parsed) + return rows + + # txt fallback: parse Request(...) 
lines only + with open(path, encoding="utf-8") as f: + for line in f: + parsed = parse_request_repr_line(line) + if isinstance(parsed, dict): + rows.append(parsed) + return rows + + def _load_items(self, dataset_root: str) -> list[dict[str, Any]]: + paths = self._expand_paths(dataset_root) + if not paths: + raise ValueError("No trace files found. Provide --dataset-path or rely on default HuggingFace download.") + + items: list[dict[str, Any]] = [] + for p in paths: + if not os.path.exists(p): + continue + for row in self._parse_trace_file(p): + if isinstance(row, dict): + row = dict(row) + row.setdefault("_source", p) + items.append(row) + + if not items: + raise ValueError("Trace dataset is empty after parsing provided paths.") + + if self.args.num_prompts is not None: + items = items[: self.args.num_prompts] + + return items + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> RequestFuncInput: + row = self.items[idx] + prompt = row.get("prompt") or row.get("text") or "" + + row_height = self._coerce_int(row.get("height")) + row_width = self._coerce_int(row.get("width")) + num_frames = self._coerce_int(row.get("num_frames")) + num_steps = self._coerce_int(row.get("num_inference_steps")) + seed = self._coerce_int(row.get("seed")) + fps = self._coerce_int(row.get("fps")) + timestamp = self._coerce_float(row.get("timestamp")) + slo_ms = self._coerce_float(row.get("slo_ms")) + image_paths = row.get("image_paths") + + override_w = self.args.width + override_h = self.args.height + if override_w is not None or override_h is not None: + width = override_w + height = override_h + else: + width = row_width + height = row_height + + return RequestFuncInput( + prompt=str(prompt), + api_url=self.api_url, + model=self.model, + width=width, + height=height, + num_frames=num_frames if num_frames is not None else self.args.num_frames, + num_inference_steps=num_steps if num_steps is not None else self.args.num_inference_steps, + seed=seed if seed is not None else self.args.seed, + fps=fps if fps is not None else self.args.fps, + timestamp=timestamp, + slo_ms=slo_ms, + image_paths=image_paths, + request_id=str(row.get("request_id")) if row.get("request_id") is not None else str(uuid.uuid4()), + ) + + def get_requests(self) -> list[RequestFuncInput]: + return [self[i] for i in range(len(self))] + + +class RandomDataset(BaseDataset): + def __init__(self, args, api_url: str, model: str): + self.args = args + self.api_url = api_url + self.model = model + self.num_prompts = args.num_prompts + + def __len__(self) -> int: + return self.num_prompts + + def __getitem__(self, idx: int) -> RequestFuncInput: + return RequestFuncInput( + prompt=f"Random prompt {idx} for benchmarking diffusion models", + api_url=self.api_url, + model=self.model, + width=self.args.width, + height=self.args.height, + num_frames=self.args.num_frames, + num_inference_steps=self.args.num_inference_steps, + seed=self.args.seed, + fps=self.args.fps, + ) + + def get_requests(self) -> list[RequestFuncInput]: + return [self[i] for i in range(len(self))] + + +def _compute_expected_latency_ms_from_base(req: RequestFuncInput, args, base_time_ms: float | None) -> float | None: + """Compute expected execution time (ms) based on a base per-step-per-frame unit time. + + Assumes linear scaling with pixel area, frame count, and num_inference_steps. + The base unit represents latency for a 16x16 resolution, single frame, single step. 
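+
+    Roughly (mirroring the computation below):
+        expected_ms = base_time_ms * max(width * height / 256, 1) * max(num_frames, 1) * max(num_inference_steps, 1)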
+ """ + + if base_time_ms is None: + return None + + width = req.width if req.width is not None else args.width + height = req.height if req.height is not None else args.height + if width is None or height is None: + return None + + frames = req.num_frames if req.num_frames is not None else args.num_frames + steps = req.num_inference_steps if req.num_inference_steps is not None else args.num_inference_steps + + frame_scale = frames if isinstance(frames, int) and frames > 0 else 1 + step_scale = steps if isinstance(steps, int) and steps > 0 else 1 + + area_units = max((float(width) * float(height)) / float(16 * 16), 1.0) + return float(base_time_ms) * area_units * frame_scale * step_scale + + +def _infer_slo_base_time_ms_from_warmups( + warmup_pairs: list[tuple[RequestFuncInput, RequestFuncOutput]], + args, +) -> float | None: + """Infer base SLO unit time from warmup requests. + + Returns the median base latency (ms) for a 16x16 resolution, single-frame, + single-step request. Only uses warmups that succeeded and have resolvable + width/height. + """ + + candidates_ms: list[float] = [] + for req, out in warmup_pairs: + if not out.success or out.latency <= 0: + continue + + width = req.width if req.width is not None else args.width + height = req.height if req.height is not None else args.height + if width is None or height is None: + continue + + frames = req.num_frames if req.num_frames is not None else args.num_frames + steps = req.num_inference_steps if req.num_inference_steps is not None else args.num_inference_steps + + frame_scale = int(frames) if isinstance(frames, int) and frames > 0 else 1 + step_scale = int(steps) if isinstance(steps, int) and steps > 0 else 1 + + area_units = max((float(width) * float(height)) / float(16 * 16), 1.0) + denom = area_units * float(frame_scale) * float(step_scale) + if denom <= 0: + continue + + candidates_ms.append((out.latency * 1000.0) / denom) + + if not candidates_ms: + return None + return float(np.median(candidates_ms)) + + +def _populate_slo_ms_from_warmups( + requests_list: list[RequestFuncInput], + warmup_pairs: list[tuple[RequestFuncInput, RequestFuncOutput]], + args, +) -> list[RequestFuncInput]: + """Populate missing RequestFuncInput.slo_ms using warmup outputs. + + - If a request already has slo_ms (e.g., trace-provided), it is kept as-is. + - If any request has slo_ms is None and we can infer base time from warmups, + we estimate each missing request's expected execution time and set: + req.slo_ms = expected_latency_ms * args.slo_scale + + Returns updated requests_list. + """ + + if not any(req.slo_ms is None for req in requests_list): + return requests_list + + base_time_ms = _infer_slo_base_time_ms_from_warmups(warmup_pairs, args) + if base_time_ms is None: + return requests_list + + slo_scale = float(getattr(args, "slo_scale", 3.0)) + if slo_scale <= 0: + raise ValueError(f"slo_scale must be positive, got {slo_scale}.") + + updated: list[RequestFuncInput] = [] + for req in requests_list: + if req.slo_ms is not None: + updated.append(req) + continue + expected_ms = _compute_expected_latency_ms_from_base(req, args, base_time_ms) + updated.append(replace(req, slo_ms=(expected_ms * slo_scale) if expected_ms is not None else None)) + + return updated + + +async def iter_requests( + requests_list: list[RequestFuncInput], + request_rate: float, +) -> AsyncGenerator[RequestFuncInput, None]: + """Yield requests using a fixed interval if request_rate is set. + + - If request_rate is inf, all requests are yielded immediately (no sleep). 
+ - Otherwise, requests are emitted at a fixed cadence of 1 / request_rate seconds. + """ + + if request_rate != float("inf"): + if request_rate <= 0: + raise ValueError(f"request_rate must be positive or inf, got {request_rate}.") + interval_s = 1.0 / float(request_rate) + + for i, req in enumerate(requests_list): + if request_rate != float("inf") and i > 0: + await asyncio.sleep(interval_s) + yield req + + +def _guess_mime_type(path: str) -> str: + mime, _ = mimetypes.guess_type(path) + return mime or "application/octet-stream" + + +def _encode_image_as_data_url(path: str) -> str: + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("utf-8") + mime = _guess_mime_type(path) + return f"data:{mime};base64,{encoded}" + + +async def async_request_chat_completions( + input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + output = RequestFuncOutput() + output.start_time = time.perf_counter() + + extra_body = dict(input.extra_body) + if input.width and input.height: + extra_body.setdefault("height", input.height) + extra_body.setdefault("width", input.width) + if input.num_frames: + extra_body.setdefault("num_frames", input.num_frames) + if input.num_inference_steps: + extra_body.setdefault("num_inference_steps", input.num_inference_steps) + if input.seed is not None: + extra_body.setdefault("seed", input.seed) + if input.fps: + extra_body.setdefault("fps", input.fps) + + if input.image_paths and len(input.image_paths) > 0: + content = [] + if input.prompt: + content.append({"type": "text", "text": input.prompt}) + for img_path in input.image_paths: + if not os.path.exists(img_path): + output.error = f"Image file not found: {img_path}" + output.success = False + if pbar: + pbar.update(1) + return output + content.append( + { + "type": "image_url", + "image_url": {"url": _encode_image_as_data_url(img_path)}, + } + ) + messages = [{"role": "user", "content": content}] + else: + messages = [{"role": "user", "content": input.prompt}] + + payload = { + "model": input.model, + "messages": messages, + } + if extra_body: + payload["extra_body"] = extra_body + + try: + async with session.post(input.api_url, json=payload) as response: + if response.status == 200: + resp_json = await response.json() + output.response_body = resp_json + output.success = True + if "peak_memory_mb" in resp_json: + output.peak_memory_mb = resp_json["peak_memory_mb"] + else: + output.error = f"HTTP {response.status}: {await response.text()}" + output.success = False + except Exception as e: + output.error = str(e) + output.success = False + + output.latency = time.perf_counter() - output.start_time + + if output.success and input.slo_ms is not None: + output.slo_achieved = (output.latency * 1000.0) <= float(input.slo_ms) + + if pbar: + pbar.update(1) + return output + + +def calculate_metrics( + outputs: list[RequestFuncOutput], + total_duration: float, + requests_list: list[RequestFuncInput], + args, + slo_enabled: bool, +): + success_outputs = [o for o in outputs if o.success] + error_outputs = [o for o in outputs if not o.success] + + num_success = len(success_outputs) + latencies = [o.latency for o in success_outputs] + peak_memories = [o.peak_memory_mb for o in success_outputs if o.peak_memory_mb > 0] + + metrics = { + "duration": total_duration, + "completed_requests": num_success, + "failed_requests": len(error_outputs), + "throughput_qps": num_success / total_duration if total_duration > 0 else 0, + "latency_mean": np.mean(latencies) if 
latencies else 0, + "latency_median": np.median(latencies) if latencies else 0, + "latency_p99": np.percentile(latencies, 99) if latencies else 0, + "latency_p50": np.percentile(latencies, 50) if latencies else 0, + "peak_memory_mb_max": max(peak_memories) if peak_memories else 0, + "peak_memory_mb_mean": np.mean(peak_memories) if peak_memories else 0, + "peak_memory_mb_median": np.median(peak_memories) if peak_memories else 0, + } + + if slo_enabled: + slo_defined_total = 0 + slo_met_success = 0 + + for req, out in zip(requests_list, outputs): + if req.slo_ms is None: + continue + slo_defined_total += 1 + if out.slo_achieved is None: + continue + if out.slo_achieved: + slo_met_success += 1 + + slo_attain_all = (slo_met_success / slo_defined_total) if slo_defined_total > 0 else 0.0 + + metrics.update( + { + "slo_attainment_rate": slo_attain_all, + "slo_met_success": slo_met_success, + "slo_scale": getattr(args, "slo_scale", 3.0), + } + ) + + return metrics + + +def wait_for_service(base_url: str, timeout: int = 120) -> None: + print(f"Waiting for service at {base_url}...") + start_time = time.time() + while True: + try: + # Try /health endpoint first + resp = requests.get(f"{base_url}/health", timeout=1) + if resp.status_code == 200: + print("Service is ready.") + break + except requests.exceptions.RequestException: + pass + + if time.time() - start_time > timeout: + raise TimeoutError(f"Service at {base_url} did not start within {timeout} seconds.") + + time.sleep(1) + + +async def benchmark(args): + # Construct base_url if not provided + if args.base_url is None: + args.base_url = f"http://{args.host}:{args.port}" + + # Setup dataset (vLLM-Omni supports diffusion via /v1/chat/completions) + api_url = f"{args.base_url}/v1/chat/completions" + request_func = async_request_chat_completions + + if args.dataset == "vbench": + dataset = VBenchDataset(args, api_url, args.model) + elif args.dataset == "trace": + dataset = TraceDataset(args, api_url, args.model) + elif args.dataset == "random": + dataset = RandomDataset(args, api_url, args.model) + else: + raise ValueError(f"Unknown dataset: {args.dataset}") + + print("Loading requests...") + requests_list = dataset.get_requests() + print(f"Prepared {len(requests_list)} requests from {args.dataset} dataset.") + + # Limit concurrency + if args.max_concurrency is not None: + semaphore = asyncio.Semaphore(args.max_concurrency) + else: + semaphore = None + + async def limited_request_func(req, session, pbar): + if semaphore: + async with semaphore: + return await request_func(req, session, pbar) + else: + return await request_func(req, session, pbar) + + # Run benchmark + pbar = tqdm(total=len(requests_list), disable=args.disable_tqdm) + + async with aiohttp.ClientSession() as session: + warmup_pairs: list[tuple[RequestFuncInput, RequestFuncOutput]] = [] + if args.warmup_requests and requests_list: + print( + f"Running {args.warmup_requests} warmup request(s) \ + with num_inference_steps={args.warmup_num_inference_steps}..." + ) + for i in range(args.warmup_requests): + warm_req = requests_list[i % len(requests_list)] + if args.warmup_num_inference_steps is not None: + warm_req = replace( + warm_req, + num_inference_steps=args.warmup_num_inference_steps, + ) + warm_out = await limited_request_func(warm_req, session, None) + warmup_pairs.append((warm_req, warm_out)) + + if args.slo: + # Prefer trace-provided per-request slo_ms. Only populate when missing. 
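+            # Sketch of the fallback: the base unit is the median over warmups of
+            # latency_ms / (width * height / 256 * frames * steps); requests without
+            # an slo_ms then get expected_ms * --slo-scale.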
+ requests_list = _populate_slo_ms_from_warmups( + requests_list=requests_list, + warmup_pairs=warmup_pairs, + args=args, + ) + + start_time = time.perf_counter() + tasks = [] + async for req in iter_requests(requests_list=requests_list, request_rate=args.request_rate): + task = asyncio.create_task(limited_request_func(req, session, pbar)) + tasks.append(task) + + outputs = await asyncio.gather(*tasks) + total_duration = time.perf_counter() - start_time + + pbar.close() + + # Calculate metrics + metrics = calculate_metrics(outputs, total_duration, requests_list, args, args.slo) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=60, c="=")) + + # Section 1: Configuration + print("{:<40} {:<15}".format("Model:", args.model)) + print("{:<40} {:<15}".format("Dataset:", args.dataset)) + print("{:<40} {:<15}".format("Task:", args.task)) + + # Section 2: Execution & Traffic + print(f"{'-' * 50}") + print("{:<40} {:<15.2f}".format("Benchmark duration (s):", metrics["duration"])) + print("{:<40} {:<15}".format("Request rate:", str(args.request_rate))) + print( + "{:<40} {:<15}".format( + "Max request concurrency:", + str(args.max_concurrency) if args.max_concurrency else "not set", + ) + ) + print("{:<40} {}/{:<15}".format("Successful requests:", metrics["completed_requests"], len(requests_list))) + + # Section 3: Performance Metrics + print(f"{'-' * 50}") + + print("{:<40} {:<15.2f}".format("Request throughput (req/s):", metrics["throughput_qps"])) + print("{:<40} {:<15.4f}".format("Latency Mean (s):", metrics["latency_mean"])) + print("{:<40} {:<15.4f}".format("Latency Median (s):", metrics["latency_median"])) + print("{:<40} {:<15.4f}".format("Latency P99 (s):", metrics["latency_p99"])) + + if args.slo: + print(f"{'-' * 50}") + print("{:<40} {:<15.2%}".format("SLO Attainment Rate (all):", metrics.get("slo_attainment_rate", 0.0))) + print("{:<40} {:<15}".format("SLO Met (success count):", str(metrics.get("slo_met_success", 0)))) + print("{:<40} {:<15}".format("SLO Scale:", str(metrics.get("slo_scale", 3.0)))) + + if metrics["peak_memory_mb_max"] > 0: + print(f"{'-' * 50}") + print("{:<40} {:<15.2f}".format("Peak Memory Max (MB):", metrics["peak_memory_mb_max"])) + print("{:<40} {:<15.2f}".format("Peak Memory Mean (MB):", metrics["peak_memory_mb_mean"])) + print("{:<40} {:<15.2f}".format("Peak Memory Median (MB):", metrics["peak_memory_mb_median"])) + + print("\n" + "=" * 60) + + if args.output_file: + with open(args.output_file, "w") as f: + json.dump(metrics, f, indent=2) + print(f"Metrics saved to {args.output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark serving for diffusion models.") + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Base URL of the server (e.g., http://localhost:8091). 
Overrides host/port.",
+    )
+    parser.add_argument("--host", type=str, default="localhost", help="Server host.")
+    parser.add_argument("--port", type=int, default=8091, help="Server port.")
+    parser.add_argument("--model", type=str, default="default", help="Model name.")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="vbench",
+        choices=["vbench", "trace", "random"],
+        help="Dataset to use.",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="t2v",
+        choices=["t2v", "i2v", "ti2v", "ti2i", "i2i", "t2i"],
+        help="Task type.",
+    )
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default=None,
+        help="Path to local dataset file (optional).",
+    )
+    parser.add_argument("--num-prompts", type=int, default=10, help="Number of prompts to benchmark.")
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=1,
+        help="Maximum number of concurrent requests, defaults to `1`. This can be used "
+        "to help simulate an environment where a higher level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument will control how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate, "
+        "if the server is not processing requests fast enough to keep up.",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
+        "Otherwise, requests are sent at a fixed interval of 1/request-rate seconds. Default is inf.",
+    )
+    parser.add_argument(
+        "--warmup-requests",
+        type=int,
+        default=1,
+        help="Number of warmup requests to run before measurement.",
+    )
+    parser.add_argument(
+        "--warmup-num-inference-steps",
+        type=int,
+        default=1,
+        help="num_inference_steps used for warmup requests.",
+    )
+    parser.add_argument("--width", type=int, default=None, help="Image/Video width.")
+    parser.add_argument("--height", type=int, default=None, help="Image/Video height.")
+    parser.add_argument("--num-frames", type=int, default=None, help="Number of frames (for video).")
+    parser.add_argument(
+        "--num-inference-steps",
+        type=int,
+        default=50,
+        help="Number of inference steps (for diffusion models).",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Random seed (for diffusion models).",
+    )
+    parser.add_argument("--fps", type=int, default=None, help="FPS (for video).")
+    parser.add_argument("--output-file", type=str, default=None, help="Output JSON file for metrics.")
+    parser.add_argument(
+        "--slo",
+        action="store_true",
+        help=(
+            "Enable SLO calculation and reporting. If trace provides per-request slo_ms, it is used. "
+            "Otherwise, warmup request(s) are used to infer expected execution time assuming linear "
+            "scaling by resolution, frames, and steps, then slo_ms = expected_time * --slo-scale."
+        ),
+    )
+    parser.add_argument(
+        "--slo-scale",
+        type=float,
+        default=3.0,
+        help="SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).",
+    )
+    parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.")
+
+    args = parser.parse_args()
+
+    asyncio.run(benchmark(args))
diff --git a/benchmarks/qwen3-omni/README.md b/benchmarks/qwen3-omni/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fdc3854451f810d0d9fc7cdc0e8b278d1b519028
--- /dev/null
+++ b/benchmarks/qwen3-omni/README.md
@@ -0,0 +1,86 @@
+# Benchmarks Guide
+
+This README explains how to (1) prepare benchmark datasets and (2) run the provided Qwen3-Omni benchmarks.
+
+## 1) Prepare the dataset (SeedTTS top100)
+
+```bash
+cd benchmarks/build_dataset
+pip install gdown
+
+# Download SeedTTS test set from Google Drive
+gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+
+# Extract
+tar -xf seedtts_testset.tar
+
+# Copy metadata and extract top-100 prompts
+cp seedtts_testset/en/meta.lst meta.lst
+python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100
+
+# (Optional) clean up to save space
+rm -rf seedtts_testset seedtts_testset.tar meta.lst
+```
+
+Artifacts:
+- `benchmarks/build_dataset/top100.txt` — 100 text prompts (one per line).
+
+## 2) Run benchmarks
+
+All commands assume repo root (`vllm-omni`).
+
+### A. Transformers benchmark (offline, HF Transformers)
+
+```bash
+bash benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
+```
+
+What it does:
+- Runs `qwen3_omni_moe_transformers.py` over `top100.txt` with `--num_prompts 100`.
+- Outputs to `benchmarks/qwen3-omni/transformers/benchmark_results/`:
+  - `perf_stats.json` — aggregated & per-prompt TPS/latency (thinker/talker/code2wav/overall).
+  - `results.json` — per-prompt outputs and audio paths.
+  - `audio/` — ~100 generated `.wav` files.
+
+Key checks:
+- `overall_tps` and `*_tps_avg` should be non-zero and reasonably stable.
+- Investigate any 0/NaN or unusually low TPS / long-tail latency.
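+
+If `jq` is available, a minimal sketch for pulling the headline numbers out of `perf_stats.json` (the field names follow the structure the benchmark script writes):
+
+```bash
+jq '.aggregated | {num_samples, overall_tps, thinker_tps_avg, talker_tps_avg, code2wav_tps_avg}' \
+    benchmarks/qwen3-omni/transformers/benchmark_results/perf_stats.json
+```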
+
+### B. vLLM Omni end-to-end benchmark (pipeline)
+
+```bash
+bash benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh
+```
+
+What it does:
+- Runs `examples/offline_inference/qwen3_omni/end2end.py` with `--enable-stats`.
+- Uses `benchmarks/build_dataset/top100.txt` and writes to:
+  - Logs: `benchmarks/qwen3-omni/vllm_omni/logs/`
+    - `omni_llm_pipeline_text.orchestrator.stats.jsonl` — per-stage latency stats.
+    - `omni_llm_pipeline_text.overall.stats.jsonl` — end-to-end latency/TPS.
+    - `omni_llm_pipeline_text.stage{0,1,2}.log` — per-stage detailed logs/errors.
+  - Outputs: `benchmarks/qwen3-omni/vllm_omni/outputs/` — ~100 text and `.wav` files.
+
+Key checks:
+- Overall stats: end-to-end latency/TPS should be reasonable.
+- Orchestrator stats: per-stage latency should be stable; investigate long tails.
+- Stage logs: ensure no errors and no unusually slow stages.
+
+## Performance snapshot
+
+The chart below summarizes our measured Qwen3-Omni MoE end-to-end benchmark, comparing vLLM-Omni against HF Transformers, and shows the overall throughput advantage of vLLM-Omni. These figures come from actual experiment runs; use them as the reference point when evaluating or reproducing the benchmark.
+
+![vLLM-Omni vs HF](./vllm-omni-vs-hf.png)
+
+## Directory layout
+- `benchmarks/build_dataset/` — dataset prep utilities (e.g., SeedTTS top100).
+- `benchmarks/<task>/vllm_omni/` — vLLM-Omni pipeline benchmarks, logs, outputs.
+- Add new tasks under `benchmarks/<task>/...` with the same pattern: `transformers/`, `vllm_omni/`, task-specific README, and (optionally) dataset prep notes.
+- `benchmarks/<task>/vllm-omni-vs-hf.png` — current performance snapshot (overall throughput comparison).
+- `benchmarks/<task>/transformers/` — HF Transformers benchmarks (offline reference).
+
+## Troubleshooting
+- Make sure GPU/driver/FlashAttention2 requirements are met for the chosen model.
+- If downloads fail, confirm network access to Google Drive (`gdown`) and Hugging Face.
+- If audio files are missing, check for errors in stage logs or model generation.
diff --git a/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh b/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2679adf4f8bc349ae7ce6b2acc2d8a0d6bc98644
--- /dev/null
+++ b/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Qwen3-Omni Transformers Benchmark Evaluation Script
+# This script must be run from the vllm-omni root directory
+
+# Get the directory where this script is located
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Navigate to the vllm-omni root directory (three levels up from the script location)
+VLLM_OMNI_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
+cd "$VLLM_OMNI_ROOT" || { echo "Error: Failed to navigate to vllm-omni directory"; exit 1; }
+
+echo "Working directory: $(pwd)"
+# Verify we're in the correct directory and run benchmark
+if [[ ! -f "benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py" ]]; then
+    echo "Error: Not in vllm-omni root directory. Please run from vllm-omni folder."
+else
+    cd benchmarks/qwen3-omni/transformers
+
+    python qwen3_omni_moe_transformers.py --prompts_file ../../build_dataset/top100.txt --num_prompts 100
+
+    echo "Logs and outputs are saved to $(pwd)/benchmark_results:"
+    echo "  - perf_stats.json  Aggregated/per-prompt TPS and latency (thinker/talker/code2wav/overall)"
+    echo "  - results.json     Per-prompt outputs and audio paths"
+    echo "  - audio/           Generated wav files; there should be 100 wav files generated"
+    echo "Key checks: overall_tps and *_tps_avg should be non-zero and stable; investigate 0/NaN or unusually low TPS/long-tail latency."
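+
+    # Optional sanity check (a sketch; assumes the run completed): count the generated wav files.
+    # ls benchmark_results/audio/*.wav | wc -l   # expected: 100 for the default prompt set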
+fi diff --git a/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py new file mode 100644 index 0000000000000000000000000000000000000000..43b56f3e9954590763e79e276e96bc1f3dfb7d9a --- /dev/null +++ b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py @@ -0,0 +1,265 @@ +import time + +import torch +from transformers import Qwen3OmniMoeForConditionalGeneration + + +class Qwen3OmniMoeForConditionalGenerationWithLogging(Qwen3OmniMoeForConditionalGeneration): + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor | None = None, + speaker: str = "Ethan", + use_audio_in_video: bool = False, + return_audio: bool | None = None, + thinker_max_new_tokens: int = 1024, + thinker_eos_token_id: int = 151645, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 50, + talker_top_p: float = 1.0, + talker_temperature: float = 0.9, + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + total_t0 = time.time() + perf_stats = { + "thinker_tokens": 0, + "thinker_time_s": 0.0, + "thinker_tps": 0.0, + "talker_tokens": 0, + "talker_time_s": 0.0, + "talker_tps": 0.0, + "code2wav_tokens": 0, + "code2wav_time_s": 0.0, + "code2wav_tps": 0.0, + "total_tokens": 0, + "total_time_s": 0.0, + "total_tps": 0.0, + } + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initialized. " + "Use `enable_talker` method or set enable_talker in config " + "to enable talker." + ) + if return_audio is None: + return_audio = self.has_talker + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + "eos_token_id": thinker_eos_token_id, + } + + talker_kwargs = {} + token2wav_kwargs = {} + if return_audio: + speaker_id = self.config.talker_config.speaker_id.get(speaker.lower()) + if speaker_id is None: + raise NotImplementedError(f"Speaker {speaker} not implemented") + if input_ids.shape[0] != 1: + raise NotImplementedError("Qwen3-Omni currently does not support batched inference with audio output") + talker_suppressed_tokens = [ + i + for i in range( + self.config.talker_config.text_config.vocab_size - 1024, + self.config.talker_config.text_config.vocab_size, + ) + if i != self.config.talker_config.codec_eos_token_id + ] # Suppress additional special tokens, should not be predicted + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": self.config.talker_config.codec_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + "suppress_tokens": talker_suppressed_tokens, + "output_hidden_states": True, + "return_dict_in_generate": True, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key in ("input_features", "attention_mask"): + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key 
not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs and key in ["image_grid_thw", "video_grid_thw", "video_second_per_grid"]: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + t0 = time.time() + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + t1 = time.time() + perf_stats["thinker_time_s"] = max(0.0, t1 - t0) + try: + prompt_len = int(input_ids.shape[1]) if input_ids is not None else 0 + total_len = int(thinker_result.sequences.shape[-1]) + thinker_out_len = max(0, total_len - prompt_len) + except Exception: + thinker_out_len = 0 + perf_stats["thinker_tokens"] = thinker_out_len + perf_stats["thinker_tps"] = ( + (thinker_out_len / perf_stats["thinker_time_s"]) if perf_stats["thinker_time_s"] > 0 else 0.0 + ) + + if not generate_audio: + perf_stats["total_tokens"] = perf_stats["thinker_tokens"] + perf_stats["total_time_s"] = time.time() - total_t0 + perf_stats["total_tps"] = ( + (perf_stats["total_tokens"] / perf_stats["total_time_s"]) if perf_stats["total_time_s"] > 0 else 0.0 + ) + # attach stats to self + setattr(self, "_perf_stats_last", perf_stats) + if not hasattr(self, "_perf_stats_history"): + setattr(self, "_perf_stats_history", []) + self._perf_stats_history.append(perf_stats) + return thinker_result, None + + # 2. Prepare talker input + thinker_embed = torch.cat([hidden_states[0] for hidden_states in thinker_result.hidden_states], dim=1).to( + self.talker.device + ) # [1 t d] + thinker_hidden = torch.cat( + [ + hidden_states[self.config.talker_config.accept_hidden_layer] + for hidden_states in thinker_result.hidden_states + ], + dim=1, + ).to(self.talker.device) # [1 t d] + + im_start_indexes = torch.cat( + ( + torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(), + torch.tensor([thinker_result.sequences.shape[-1]], device=input_ids.device, dtype=input_ids.dtype), + ), + dim=-1, + ).to(self.talker.device) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here. 
+ multimodal_mask = ( + (thinker_result.sequences == self.config.thinker_config.audio_token_id) | + (thinker_result.sequences == self.config.thinker_config.image_token_id) | + (thinker_result.sequences == self.config.thinker_config.video_token_id) + ).to(self.talker.device) # [1 t] # fmt: skip + + talker_special_tokens = torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=self.thinker.device, + dtype=input_ids.dtype, + ) + tts_bos_embed, tts_eos_embed, tts_pad_embed = ( + self.talker.text_projection(self.thinker.get_input_embeddings()(talker_special_tokens)) + .to(self.talker.device) + .chunk(3, dim=1) + ) # 3 * [1 1 d] + + talker_input_embeds = [] # [1 t d] + talker_input_ids = [] + # For every chatml parts + for i in range(len(im_start_indexes) - 1): + im_start_index = im_start_indexes[i] + segment_end_index = im_start_indexes[i + 1] + role_token = input_ids[0][im_start_index + 1] + # Talker should ignore thinker system prompt + if role_token == self.config.system_token_id: + continue + # Talker takes word embeddings for tokens and hidden state from `accept_hidden_layer` for multimodal inputs + elif role_token == self.config.user_token_id: + talker_user_part = self._get_talker_user_parts( + im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed + ) + talker_input_embeds.append(talker_user_part) + talker_input_ids.append(thinker_result.sequences[:, im_start_index:segment_end_index]) + # Take assistant output (for now) + elif role_token == self.config.assistant_token_id and i == len(im_start_indexes) - 2: + talker_assistant_embeds, talker_assistant_ids, trailing_text_hidden = self._get_talker_assistant_parts( + im_start_index, + segment_end_index, + speaker_id, + thinker_embed, + tts_pad_embed, + tts_bos_embed, + tts_eos_embed, + ) + talker_input_embeds.append(talker_assistant_embeds) + talker_input_ids.append(talker_assistant_ids) + # History assistant output (ignore for now) + elif role_token == self.config.assistant_token_id and i != len(im_start_indexes) - 2: + continue + else: + raise AssertionError("Expect role id after <|im_start|> (assistant, user, system)") + talker_input_embed = torch.cat([embed.to(self.talker.device) for embed in talker_input_embeds], dim=1) + talker_input_id = torch.cat([embed.to(self.talker.device) for embed in talker_input_ids], dim=1) + t2 = time.time() + talker_result = self.talker.generate( + inputs_embeds=talker_input_embed, + trailing_text_hidden=trailing_text_hidden, + tts_pad_embed=tts_pad_embed, + talker_input_ids=talker_input_id, # Not use input_ids to prevent repetition penalty out of bound + **talker_kwargs, + ) + t3 = time.time() + perf_stats["talker_time_s"] = max(0.0, t3 - t2) + talker_codes = ( + torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1) + .transpose(1, 2) + .to(self.code2wav.device) + ) + try: + # codes shape: (B, num_quantizers, T). We log T as token length. 
+ perf_stats["talker_tokens"] = int(talker_codes.shape[-1]) + except Exception: + perf_stats["talker_tokens"] = 0 + perf_stats["talker_tps"] = ( + (perf_stats["talker_tokens"] / perf_stats["talker_time_s"]) if perf_stats["talker_time_s"] > 0 else 0.0 + ) + t4 = time.time() + talker_wavs = self.code2wav.chunked_decode(talker_codes, chunk_size=300, left_context_size=25).float() + t5 = time.time() + perf_stats["code2wav_time_s"] = max(0.0, t5 - t4) + perf_stats["code2wav_tokens"] = perf_stats["talker_tokens"] # same T, not times 16 + perf_stats["code2wav_tps"] = ( + (perf_stats["code2wav_tokens"] / perf_stats["code2wav_time_s"]) + if perf_stats["code2wav_time_s"] > 0 + else 0.0 + ) + perf_stats["total_tokens"] = perf_stats["thinker_tokens"] + perf_stats["talker_tokens"] + perf_stats["total_time_s"] = time.time() - total_t0 + perf_stats["total_tps"] = ( + (perf_stats["total_tokens"] / perf_stats["total_time_s"]) if perf_stats["total_time_s"] > 0 else 0.0 + ) + setattr(self, "_perf_stats_last", perf_stats) + if not hasattr(self, "_perf_stats_history"): + setattr(self, "_perf_stats_history", []) + self._perf_stats_history.append(perf_stats) + return thinker_result, talker_wavs.float() + + +__all__ = [ + "Qwen3OmniMoeForConditionalGenerationWithLogging", +] diff --git a/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..68ab2a1e39da8904556960fbdbb1f702ce44f82e --- /dev/null +++ b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py @@ -0,0 +1,275 @@ +import argparse +import json +import os + +import soundfile as sf +from qwen3_omni_moe_model import Qwen3OmniMoeForConditionalGenerationWithLogging +from qwen_omni_utils import process_mm_info +from tqdm import tqdm +from transformers import Qwen3OmniMoeProcessor + +MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct" +# MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Thinking" + + +def load_prompts(prompts_file: str) -> list[str]: + """Load prompts from a text file, one prompt per line.""" + prompts = [] + with open(prompts_file, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + prompts.append(line) + return prompts + + +def run_benchmark( + model, + processor, + prompts: list[str], + output_dir: str = "benchmark_results", + speaker: str = "Ethan", + use_audio_in_video: bool = True, +): + """ + Run benchmark on a list of prompts and collect performance stats. 
+ + Args: + model: The Qwen3OmniMoe model + processor: The Qwen3OmniMoe processor + prompts: List of text prompts to process + output_dir: Directory to save results + speaker: Speaker voice for audio output + use_audio_in_video: Whether to use audio in video + + Returns: + tuple: (aggregated_stats, results, audio_outputs) + - aggregated_stats: dict with aggregated performance statistics + - results: list of dicts with per-prompt results + - audio_outputs: list of audio tensors/arrays (or None if no audio) + """ + os.makedirs(output_dir, exist_ok=True) + audio_dir = os.path.join(output_dir, "audio") + os.makedirs(audio_dir, exist_ok=True) + + all_stats = [] + results = [] + audio_outputs = [] + + for idx, prompt in enumerate(tqdm(prompts, desc="Processing prompts")): + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + }, + ] + + # Preparation for inference + text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + audios, images, videos = process_mm_info(conversation, use_audio_in_video=use_audio_in_video) + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + + # Inference: Generation of the output text and audio + text_ids, audio = model.generate( + **inputs, speaker=speaker, thinker_return_dict_in_generate=True, use_audio_in_video=use_audio_in_video + ) + + # Decode output text + output_text = processor.batch_decode( + text_ids.sequences[:, inputs["input_ids"].shape[1] :], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0] + + # Collect performance stats + perf_stats = None + if hasattr(model, "_perf_stats_last"): + perf_stats = model._perf_stats_last.copy() + perf_stats["prompt_idx"] = idx + perf_stats["prompt"] = prompt + all_stats.append(perf_stats) + + # Save audio and collect audio output + audio_path = None + audio_data = None + if audio is not None: + audio_data = audio.reshape(-1).detach().cpu().numpy() + audio_path = os.path.join(audio_dir, f"output_{idx:04d}.wav") + sf.write( + audio_path, + audio_data, + samplerate=24000, + ) + audio_outputs.append(audio_data) + else: + audio_outputs.append(None) + + # Save result + result = { + "idx": idx, + "prompt": prompt, + "output": output_text, + "audio_path": audio_path, + "perf_stats": perf_stats, + } + results.append(result) + + # Aggregate statistics + aggregated_stats = aggregate_stats(all_stats) + + # Save all results + results_path = os.path.join(output_dir, "results.json") + with open(results_path, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + # Save aggregated stats + stats_path = os.path.join(output_dir, "perf_stats.json") + with open(stats_path, "w", encoding="utf-8") as f: + json.dump({"aggregated": aggregated_stats, "per_prompt": all_stats}, f, ensure_ascii=False, indent=2) + + # Count saved audio files + num_audio_saved = sum(1 for a in audio_outputs if a is not None) + print(f"\nSaved {num_audio_saved} audio files to {audio_dir}/") + + return aggregated_stats, results, audio_outputs + + +def aggregate_stats(all_stats: list[dict]) -> dict: + """Aggregate performance statistics from multiple runs.""" + if not all_stats: + return {} + + keys = [ + "thinker_tokens", + "thinker_time_s", + "thinker_tps", + "talker_tokens", + "talker_time_s", + "talker_tps", + "code2wav_tokens", + "code2wav_time_s", + "code2wav_tps", + 
"total_tokens", + "total_time_s", + "total_tps", + ] + + aggregated = { + "num_samples": len(all_stats), + } + + for key in keys: + values = [s.get(key, 0) for s in all_stats if key in s] + if values: + aggregated[f"{key}_sum"] = sum(values) + aggregated[f"{key}_avg"] = sum(values) / len(values) + aggregated[f"{key}_min"] = min(values) + aggregated[f"{key}_max"] = max(values) + + # Calculate overall throughput + total_tokens = aggregated.get("total_tokens_sum", 0) + total_time = aggregated.get("total_time_s_sum", 0) + if total_time > 0: + aggregated["overall_tps"] = total_tokens / total_time + + return aggregated + + +def print_stats(stats: dict): + """Print performance statistics in a formatted way.""" + print("\n" + "=" * 60) + print("Performance Statistics Summary") + print("=" * 60) + + print(f"\nNumber of samples: {stats.get('num_samples', 0)}") + + print("\n--- Thinker ---") + print(f" Total tokens: {stats.get('thinker_tokens_sum', 0):.0f}") + print(f" Total time: {stats.get('thinker_time_s_sum', 0):.2f}s") + print(f" Avg TPS: {stats.get('thinker_tps_avg', 0):.2f}") + print(f" Min TPS: {stats.get('thinker_tps_min', 0):.2f}") + print(f" Max TPS: {stats.get('thinker_tps_max', 0):.2f}") + + print("\n--- Talker ---") + print(f" Total tokens: {stats.get('talker_tokens_sum', 0):.0f}") + print(f" Total time: {stats.get('talker_time_s_sum', 0):.2f}s") + print(f" Avg TPS: {stats.get('talker_tps_avg', 0):.2f}") + print(f" Min TPS: {stats.get('talker_tps_min', 0):.2f}") + print(f" Max TPS: {stats.get('talker_tps_max', 0):.2f}") + + print("\n--- Code2Wav ---") + print(f" Total tokens: {stats.get('code2wav_tokens_sum', 0):.0f}") + print(f" Total time: {stats.get('code2wav_time_s_sum', 0):.2f}s") + print(f" Avg TPS: {stats.get('code2wav_tps_avg', 0):.2f}") + print(f" Min TPS: {stats.get('code2wav_tps_min', 0):.2f}") + print(f" Max TPS: {stats.get('code2wav_tps_max', 0):.2f}") + + print("\n--- Overall ---") + print(f" Total tokens: {stats.get('total_tokens_sum', 0):.0f}") + print(f" Total time: {stats.get('total_time_s_sum', 0):.2f}s") + print(f" Overall TPS: {stats.get('overall_tps', 0):.2f}") + print(f" Avg TPS: {stats.get('total_tps_avg', 0):.2f}") + print(f" Min TPS: {stats.get('total_tps_min', 0):.2f}") + print(f" Max TPS: {stats.get('total_tps_max', 0):.2f}") + + print("=" * 60 + "\n") + + +def main(): + parser = argparse.ArgumentParser(description="Qwen3-Omni Benchmark Script") + parser.add_argument( + "--prompts_file", + type=str, + default="benchmark/build_dataset/top100.txt", + help="Path to the prompts file (one prompt per line)", + ) + parser.add_argument( + "--output_dir", type=str, default="benchmark_results", help="Directory to save benchmark results" + ) + parser.add_argument("--model_path", type=str, default=MODEL_PATH, help="Path to the model") + parser.add_argument("--speaker", type=str, default="Ethan", help="Speaker voice for audio output") + parser.add_argument("--num_prompts", type=int, default=None, help="Number of prompts to process (default: all)") + args = parser.parse_args() + + # Load model and processor + print(f"Loading model from {args.model_path}...") + model = Qwen3OmniMoeForConditionalGenerationWithLogging.from_pretrained( + args.model_path, + dtype="auto", + device_map="auto", + attn_implementation="flash_attention_2", + ) + processor = Qwen3OmniMoeProcessor.from_pretrained(args.model_path) + + # Benchmark mode + print(f"Loading prompts from {args.prompts_file}...") + prompts = load_prompts(args.prompts_file) + + if args.num_prompts: + prompts = prompts[: 
args.num_prompts] + + print(f"Running benchmark on {len(prompts)} prompts...") + + aggregated_stats, results, audio_outputs = run_benchmark( + model=model, + processor=processor, + prompts=prompts, + output_dir=args.output_dir, + speaker=args.speaker, + ) + + print_stats(aggregated_stats) + print(f"\nResults saved to {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/qwen3-omni/vllm-omni-vs-hf.png b/benchmarks/qwen3-omni/vllm-omni-vs-hf.png new file mode 100644 index 0000000000000000000000000000000000000000..e47079335be717c3f1f296559745794b69dcea26 Binary files /dev/null and b/benchmarks/qwen3-omni/vllm-omni-vs-hf.png differ diff --git a/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh new file mode 100644 index 0000000000000000000000000000000000000000..61e46f8c3ea63dd756c055010d8df71f3c017b00 --- /dev/null +++ b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Qwen3-Omni Benchmark Evaluation Script +# This script must be run from the vllm-omni root directory + +# Get the directory where this script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Navigate to vllm-omni root directory (4 levels up from script location) +VLLM_OMNI_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" +cd "$VLLM_OMNI_ROOT" || { echo "Error: Failed to navigate to vllm-omni directory"; exit 1; } + +echo "Working directory: $(pwd)" + +# Verify we're in the correct directory and run benchmark +if [[ ! -d "benchmarks/qwen3-omni/vllm_omni" ]]; then + echo "Error: Not in vllm-omni root directory. Please run from vllm-omni folder." +else + log_dir=benchmarks/qwen3-omni/vllm_omni/logs + outputs_dir=benchmarks/qwen3-omni/vllm_omni/outputs + end2end_script_path=examples/offline_inference/qwen3_omni/end2end.py + build_dataset_path=benchmarks/build_dataset/top100.txt + + python $end2end_script_path --output-wav $outputs_dir \ + --query-type text \ + --txt-prompts $build_dataset_path \ + --enable-stats \ + --log-dir $log_dir + echo "Logs and outputs are saved in ${log_dir} and ${outputs_dir} respectively:" + echo " - omni_llm_pipeline_text run dir/base name" + echo " - omni_llm_pipeline_text.orchestrator.stats.jsonl orchestrator-stage latency stats" + echo " - omni_llm_pipeline_text.overall.stats.jsonl overall latency/TPS stats" + echo " - omni_llm_pipeline_text.stage0.log per-stage detailed logs" + echo " - omni_llm_pipeline_text.stage1.log" + echo " - omni_llm_pipeline_text.stage2.log" + echo "Key checks: overall.stats.jsonl for end-to-end latency/TPS; orchestrator.stats.jsonl for stable per-stage latency; stage*.log for errors or long tails." + echo " - outputs/ Generated txt and wav files, there should be 100 text and wav files generated respectively" +fi diff --git a/collect_env.py b/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..8b09379e1a33afdd9675b4c7e72c815f848dd8c7 --- /dev/null +++ b/collect_env.py @@ -0,0 +1,760 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa +# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py + +import datetime +import locale +import os +import subprocess +import sys + +# Unlike the rest of the PyTorch this file must be python2 compliant. 
+# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +from collections import namedtuple + +import regex as re + +from vllm.envs import environment_variables + +try: + import torch + + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple( + "SystemEnv", + [ + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + "rocm_version", # vllm specific field + "vllm_version", # vllm specific field + "vllm_omni_version", # vllm-omni specific field + "vllm_build_flags", # vllm specific field + "gpu_topo", # vllm specific field + "env_vars", + ], +) + +DEFAULT_CONDA_PATTERNS = { + "torch", + "numpy", + "cudatoolkit", + "soumith", + "mkl", + "magma", + "triton", + "optree", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", + "flashinfer-python", +} + +DEFAULT_PIP_PATTERNS = { + "torch", + "numpy", + "mypy", + "flake8", + "triton", + "optree", + "onnx", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", + "flashinfer-python", +} + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + try: + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == "win32": + enc = "oem" + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + if command == "nvidia-smi topo -m": + # don't remove the leading whitespace of `nvidia-smi topo -m` + # because they are meaningful + output = output.rstrip() + else: + output = output.strip() + err = raw_err.decode(enc) + return rc, output, err.strip() + + except FileNotFoundError: + cmd_str = command if isinstance(command, str) else command[0] + return 127, "", f"Command not found: {cmd_str}" + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out.split("\n")[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = DEFAULT_CONDA_PATTERNS + conda = os.environ.get("CONDA_EXE", "conda") + out = run_and_read_all(run_lambda, [conda, "list"]) + if out is None: + return out + + return "\n".join( + line for line in out.splitlines() if not line.startswith("#") and any(name in line for name in 
patterns) + ) + + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") + + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, "clang --version", r"clang version (.*)") + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match(run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]") + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ") + + +def get_gpu_info(run_lambda): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE and hasattr(torch.version, "hip") and torch.version.hip is not None + ): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, "", out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == "darwin": + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get("CUDNN_LIBRARY") + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split("\n"): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join(program_files_root, "NVIDIA Corporation", "NVSMI", smi) + new_path = os.path.join(system_root, "System32", smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +def get_rocm_version(run_lambda): + """Returns the ROCm version if available, otherwise 'N/A'.""" + return run_and_parse_first_match(run_lambda, "hipcc --version", r"HIP version: (\S+)") + + +def get_vllm_version(): + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith("g"): + # it's a dev build + if "." in version_str: + # it's a dev build containing local changes + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" + return __version__ + + +def get_vllm_omni_version(run_lambda): + try: + import vllm_omni + from vllm_omni import __version__, __version_tuple__ + + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith("g"): + if "." in version_str: + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + git_sha = version_str[1:] + return f"{__version__} (git sha: {git_sha})" + + package_dir = os.path.dirname(os.path.abspath(vllm_omni.__file__)) + git_sha = run_and_read_all(run_lambda, f"git -C {package_dir} rev-parse --short HEAD") + if git_sha: + return f"{__version__} (git sha: {git_sha})" + + return __version__ + except ImportError: + return "N/A (vllm_omni not installed)" + + +def summarize_vllm_build_flags(): + # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
+ return "CUDA Archs: {}; ROCm: {}".format( + os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), + "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", + ) + + +def get_gpu_topo(run_lambda): + output = None + + if get_platform() == "linux": + output = run_and_read_all(run_lambda, "nvidia-smi topo -m") + if output is None: + output = run_and_read_all(run_lambda, "rocm-smi --showtopo") + + return output + + +def get_cpu_info(run_lambda): + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": + rc, out, err = run_lambda( + "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" + ) + elif get_platform() == "darwin": + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = "None" + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") + + +def get_windows_version(run_lambda): + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") + findstr_cmd = os.path.join(system_root, "System32", "findstr") + return run_and_read_all(run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd)) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, "lsb_release -a", r"Description:\t(.*)") + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + + platform = get_platform() + + if platform == "win32" or platform == "cygwin": + return get_windows_version(run_lambda) + + if platform == "darwin": + version = get_mac_version(run_lambda) + if version is None: + return None + return "macOS {} ({})".format(version, machine()) + + if platform == "linux": + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return "{} ({})".format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return "{} ({})".format(desc, machine()) + + return "{} ({})".format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + + return platform.platform() + + +def get_libc_version(): + import platform + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) + + +def is_uv_venv(): + if os.environ.get("UV"): + return True + pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg") + if os.path.exists(pyvenv_cfg_path): + with open(pyvenv_cfg_path, "r") as f: + return any(line.startswith("uv = ") for line in f) + return False + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = DEFAULT_PIP_PATTERNS + + def run_with_pip(): + try: + import importlib.util + + pip_spec = importlib.util.find_spec("pip") + pip_available = pip_spec is not None + except ImportError: + pip_available = False + + if pip_available: + cmd = [sys.executable, "-mpip", "list", "--format=freeze"] + elif is_uv_venv(): + print("uv is set") + cmd = ["uv", "pip", "list", "--format=freeze"] + else: + raise RuntimeError("Could not collect pip list output (pip or uv module not available)") + + out = run_and_read_all(run_lambda, cmd) + return "\n".join(line for line in out.splitlines() if any(name in line for name in patterns)) + + pip_version = "pip3" if sys.version[0] == "3" else "pip" + out = run_with_pip() + return pip_version, out + + +def get_cachingallocator_config(): + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get("CUDA_MODULE_LOADING", "") + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + + +def get_env_vars(): + env_vars = "" + secret_terms = ("secret", "token", "api", "access", "password") + report_prefix = ( + "TORCH", + "NCCL", + "PYTORCH", + "CUDA", + "CUBLAS", + "CUDNN", + "OMP_", + "MKL_", + "NVIDIA", + ) + for k, v in os.environ.items(): + if any(term in k.lower() for term in secret_terms): + continue + if k in environment_variables: + env_vars = env_vars + "{}={}".format(k, v) + "\n" + if k.startswith(report_prefix): + env_vars = env_vars + "{}={}".format(k, v) + "\n" + + return env_vars + + +def get_env_info(): + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if not hasattr(torch.version, "hip") or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + else: # HIP version + + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else "N/A" + + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + rocm_version = get_rocm_version(run_lambda) + vllm_version = get_vllm_version() + vllm_omni_version = get_vllm_omni_version(run_lambda) + vllm_build_flags = summarize_vllm_build_flags() + gpu_topo = get_gpu_topo(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version="{} ({}-bit runtime)".format(sys_version, sys.maxsize.bit_length() + 1), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + 
cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + rocm_version=rocm_version, + vllm_version=vllm_version, + vllm_omni_version=vllm_omni_version, + vllm_build_flags=vllm_build_flags, + gpu_topo=gpu_topo, + env_vars=get_env_vars(), + ) + + +env_info_fmt = """ +============================== + System Info +============================== +OS : {os} +GCC version : {gcc_version} +Clang version : {clang_version} +CMake version : {cmake_version} +Libc version : {libc_version} + +============================== + PyTorch Info +============================== +PyTorch version : {torch_version} +Is debug build : {is_debug_build} +CUDA used to build PyTorch : {cuda_compiled_version} +ROCM used to build PyTorch : {hip_compiled_version} + +============================== + Python Environment +============================== +Python version : {python_version} +Python platform : {python_platform} + +============================== + CUDA / GPU Info +============================== +Is CUDA available : {is_cuda_available} +CUDA runtime version : {cuda_runtime_version} +CUDA_MODULE_LOADING set to : {cuda_module_loading} +GPU models and configuration : {nvidia_gpu_models} +Nvidia driver version : {nvidia_driver_version} +cuDNN version : {cudnn_version} +HIP runtime version : {hip_runtime_version} +MIOpen runtime version : {miopen_runtime_version} +Is XNNPACK available : {is_xnnpack_available} + +============================== + CPU Info +============================== +{cpu_info} + +============================== +Versions of relevant libraries +============================== +{pip_packages} +{conda_packages} +""".strip() + +# both the above code and the following code use `strip()` to +# remove leading/trailing whitespaces, so we need to add a newline +# in between to separate the two sections +env_info_fmt += "\n\n" + +env_info_fmt += """ +============================== + vLLM Info +============================== +ROCM Version : {rocm_version} +vLLM Version : {vllm_version} +vLLM-Omni Version : {vllm_omni_version} +vLLM Build Flags: + {vllm_build_flags} +GPU Topology: + {gpu_topo} + +============================== + Environment Variables +============================== +{env_vars} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement="Could not collect"): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true="Yes", false="No"): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag="[prepend]"): + lines = text.split("\n") + updated_lines = [tag + line for line in lines] + return "\n".join(updated_lines) + + def replace_if_empty(text, 
replacement="No relevant packages"): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", + ] + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = "No CUDA" + if envinfo.cuda_compiled_version is None: + mutable_dict["cuda_compiled_version"] = "None" + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend(mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend(mutable_dict["conda_packages"], "[conda] ") + mutable_dict["cpu_info"] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, "utils") and hasattr(torch.utils, "_crash_handler"): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime("%Y-%m-%d %H:%M:%S") + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + + "if this is related to your bug please include it when you file a report ***" + ) + print(msg, file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci new file mode 100644 index 0000000000000000000000000000000000000000..c5d84734584765d606b763d06ad9954e90dd86b4 --- /dev/null +++ b/docker/Dockerfile.ci @@ -0,0 +1,20 @@ +ARG VLLM_BASE_IMAGE=vllm/vllm-openai +ARG VLLM_BASE_TAG=v0.15.0 +FROM ${VLLM_BASE_IMAGE}:${VLLM_BASE_TAG} +ARG APP_DIR=/workspace/vllm-omni +WORKDIR ${APP_DIR} + +COPY . . + +# Install system dependencies +RUN apt-get update && \ + apt-get install -y ffmpeg && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install vllm-omni into the same uv-managed Python environment used by the base image. 
+RUN uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" + +RUN ln -sf /usr/bin/python3 /usr/bin/python + +ENTRYPOINT [] diff --git a/docker/Dockerfile.ci.npu b/docker/Dockerfile.ci.npu new file mode 100644 index 0000000000000000000000000000000000000000..cdf7a70f3a66079cbc48a12ab4d1628c60e94429 --- /dev/null +++ b/docker/Dockerfile.ci.npu @@ -0,0 +1,15 @@ +ARG VLLM_ASCEND_IMAGE=quay.nju.edu.cn/ascend/vllm-ascend +ARG VLLM_ASCEND_TAG=v0.11.0rc2 +FROM ${VLLM_ASCEND_IMAGE}:${VLLM_ASCEND_TAG} + +ARG APP_DIR=/vllm-workspace/vllm-omni +WORKDIR ${APP_DIR} + +COPY . . + +# Install vllm-omni with dev dependencies +RUN pip install --no-cache-dir -e ".[dev]" + +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn + +ENTRYPOINT [] diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100644 index 0000000000000000000000000000000000000000..bbb75a196178a6b89be1842e9e398eda2a74b9d8 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,32 @@ +ARG BASE_IMAGE=vllm/vllm-openai-rocm:v0.15.0 +FROM ${BASE_IMAGE} AS final + +ARG COMMON_WORKDIR=/app + +WORKDIR ${COMMON_WORKDIR} + +# Step 1: Setup - Install system dependencies +RUN apt-get update && \ + apt-get install -y ffmpeg && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN mkdir -p ${COMMON_WORKDIR}/vllm-omni + +# Step 2: Copy vllm-omni code and install without uv +COPY . ${COMMON_WORKDIR}/vllm-omni +RUN cd ${COMMON_WORKDIR}/vllm-omni && uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" + +# When we are installing onnxruntime-rocm, we need to uninstall the system-installed onnxruntime first. +# These are the dependencies of Qwen3-TTS. +RUN uv pip uninstall onnxruntime --system && uv pip install --no-cache-dir onnxruntime-rocm sox --system + +RUN ln -sf /usr/bin/python3 /usr/bin/python + +CMD ["/bin/bash"] + +ENTRYPOINT [] + +#Set entrypoint for vllm-openai official images +FROM final AS vllm-openai +ENTRYPOINT ["vllm", "serve", "--omni"] diff --git a/docs/.nav.yml b/docs/.nav.yml new file mode 100644 index 0000000000000000000000000000000000000000..428fd16b570fcd7e8b1d7b10ea42529d03494bc3 --- /dev/null +++ b/docs/.nav.yml @@ -0,0 +1,70 @@ +nav: +- Home: README.md +- User Guide: + - Getting Started: + - getting_started/quickstart.md + - getting_started/installation/* + - Serving: + - OpenAI-Compatible API: + - Image Generation: serving/image_generation_api.md + - Image Edit: serving/image_edit_api.md + - Examples: + - examples/README.md + - Offline Inference: + - Image-To-Image: user_guide/examples/offline_inference/image_to_image.md + - Image-To-Video: user_guide/examples/offline_inference/image_to_video.md + - Qwen2.5-Omni: user_guide/examples/offline_inference/qwen2_5_omni.md + - Qwen3-Omni: user_guide/examples/offline_inference/qwen3_omni.md + - Qwen3-TTS Offline Inference: user_guide/examples/offline_inference/qwen3_tts.md + - Text-To-Image: user_guide/examples/offline_inference/text_to_image.md + - Text-To-Video: user_guide/examples/offline_inference/text_to_video.md + - Online Serving: + - Image-To-Image: user_guide/examples/online_serving/image_to_image.md + - Qwen2.5-Omni: user_guide/examples/online_serving/qwen2_5_omni.md + - Qwen3-Omni: user_guide/examples/online_serving/qwen3_omni.md + - Text-To-Image: user_guide/examples/online_serving/text_to_image.md + - General: + - usage/* + - Configuration: + - configuration/README.md + - configuration/* + - Models: + - models/supported_models.md + - Features: + - Sleep Mode: 
features/sleep_mode.md + - Diffusion Features: + - Overview: user_guide/diffusion_acceleration.md + - TeaCache: user_guide/diffusion/teacache.md + - Cache-DiT: user_guide/diffusion/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md + - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md +- Developer Guide: + - General: + - contributing/README.md + - glob: contributing/* + flatten_single_child_sections: true + - Model Implementation: + - contributing/model/README.md + - contributing/model/adding_omni_model.md + - contributing/model/adding_diffusion_model.md + - CI: contributing/ci + - Design Documents: + - design/index.md + - design/architecture_overview.md + - Feature Design: + - design/feature/disaggregated_inference.md + - design/feature/ray_based_execution.md + - Module Design: + - design/module/ar_module.md + - design/module/dit_module.md + - design/module/entrypoint_module.md + - Docs Guide: contributing/DOCS_GUIDE.md +- API Reference: + - api/README.md + - api/vllm_omni +- CLI Reference: cli +- Community: + - community/* + - Slack: https://slack.vllm.ai + - Blog: https://blog.vllm.ai + - Forum: https://discuss.vllm.ai diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..363f579c992804aacf8fdd12fad046ab62f27351 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,64 @@ +--- +hide: + - navigation + - toc +--- + +# Welcome to vLLM-Omni + +

+<!-- vllm-omni logo image -->
+
+Easy, fast, and cheap omni-modality model serving for everyone
+
+<!-- GitHub Star / Watch / Fork badge links -->

+ + +## About + +[vLLM](https://github.com/vllm-project/vllm) was originally designed to support large language models for text-based autoregressive generation tasks. vLLM-Omni is a framework that extends its support for omni-modality model inference and serving: + +- **Omni-modality**: Text, image, video, and audio data processing +- **Non-autoregressive Architectures**: extend the AR support of vLLM to Diffusion Transformers (DiT) and other parallel generation models +- **Heterogeneous outputs**: from traditional text generation to multimodal outputs + +

+<!-- vllm-omni architecture overview image (vllm-omni-arch) -->

+ +vLLM-Omni is fast with: + +- State-of-the-art AR support by leveraging efficient KV cache management from vLLM +- Pipelined stage execution overlapping for high throughput performance +- Fully disaggregation based on OmniConnector and dynamic resource allocation across stages + +vLLM-Omni is flexible and easy to use with: + +- Heterogeneous pipeline abstraction to manage complex model workflows +- Seamless integration with popular Hugging Face models +- Tensor, pipeline, data and expert parallelism support for distributed inference +- Streaming outputs +- OpenAI-compatible API server + +vLLM-Omni seamlessly supports most popular open-source models on HuggingFace, including: + +- Omni-modality models (e.g. Qwen2.5-Omni, Qwen3-Omni) +- Multi-modality generation models (e.g. Qwen-Image) + +For more information, checkout the following: + +- [vllm-omni architecture design and recent roadmaps](https://docs.google.com/presentation/d/1qv4qMW1rKAqDREMXiUDLIgqqHQe7TDPj/edit?usp=sharing&ouid=110473603432222024453&rtpof=true&sd=true) +- [vllm-omni announcement blogpost](https://blog.vllm.ai/2025/11/30/vllm-omni.html) diff --git a/docs/api/README.md b/docs/api/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a110f491e643fe5f23fd8a38d3c97516d59776a --- /dev/null +++ b/docs/api/README.md @@ -0,0 +1,107 @@ +# Summary + +## Entry Points + +Main entry points for vLLM-Omni inference and serving. + +- [vllm_omni.entrypoints.async_omni.AsyncOmni][] +- [vllm_omni.entrypoints.async_omni_diffusion.AsyncOmniDiffusion][] +- [vllm_omni.entrypoints.async_omni_llm.AsyncOmniLLM][] +- [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalContentParser][] +- [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalItemTracker][] +- [vllm_omni.entrypoints.chat_utils.parse_chat_messages_futures][] +- [vllm_omni.entrypoints.cli.serve.OmniServeCommand][] +- [vllm_omni.entrypoints.client_request_state.ClientRequestState][] +- [vllm_omni.entrypoints.log_utils.OrchestratorMetrics][] +- [vllm_omni.entrypoints.log_utils.StageRequestMetrics][] +- [vllm_omni.entrypoints.log_utils.StageStats][] +- [vllm_omni.entrypoints.omni.Omni][] +- [vllm_omni.entrypoints.omni.OmniBase][] +- [vllm_omni.entrypoints.omni_diffusion.OmniDiffusion][] +- [vllm_omni.entrypoints.omni_llm.OmniLLM][] +- [vllm_omni.entrypoints.omni_stage.OmniStage][] +- [vllm_omni.entrypoints.stage_utils.OmniStageTaskType][] + +## Inputs + +Input data structures for multi-modal inputs. + +- [vllm_omni.inputs.data.OmniEmbedsPrompt][] +- [vllm_omni.inputs.data.OmniTokenInputs][] +- [vllm_omni.inputs.data.OmniTokensPrompt][] +- [vllm_omni.inputs.parse.parse_singleton_prompt_omni][] +- [vllm_omni.inputs.preprocess.OmniInputPreprocessor][] + +## Engine + +Engine classes for offline and online inference. + +- [vllm_omni.diffusion.diffusion_engine.DiffusionEngine][] +- [vllm_omni.engine.AdditionalInformationEntry][] +- [vllm_omni.engine.AdditionalInformationPayload][] +- [vllm_omni.engine.OmniEngineCoreOutput][] +- [vllm_omni.engine.OmniEngineCoreOutputs][] +- [vllm_omni.engine.OmniEngineCoreRequest][] +- [vllm_omni.engine.PromptEmbedsPayload][] +- [vllm_omni.engine.arg_utils.AsyncOmniEngineArgs][] +- [vllm_omni.engine.arg_utils.OmniEngineArgs][] +- [vllm_omni.engine.input_processor.OmniInputProcessor][] +- [vllm_omni.engine.output_processor.MultimodalOutputProcessor][] +- [vllm_omni.engine.output_processor.OmniRequestState][] + +## Core + +Core scheduling and caching components. 
+ +- [vllm_omni.core.sched.omni_ar_scheduler.KVCacheTransferData][] +- [vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler][] +- [vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler][] +- [vllm_omni.core.sched.output.OmniCachedRequestData][] +- [vllm_omni.core.sched.output.OmniNewRequestData][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.vq.core_vq.DistributedGroupResidualVectorQuantization][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.vq.core_vq.DistributedResidualVectorQuantization][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.vq.core_vq.EuclideanCodebook][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.vq.core_vq.VectorQuantization][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.vq.core_vq.preprocess][] + +## Configuration + +Configuration classes. + +- [vllm_omni.config.model.OmniModelConfig][] +- [vllm_omni.diffusion.cache.teacache.config.TeaCacheConfig][] +- [vllm_omni.distributed.omni_connectors.utils.config.ConnectorSpec][] +- [vllm_omni.distributed.omni_connectors.utils.config.OmniTransferConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts.Qwen3TTSConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts.Qwen3TTSSpeakerEncoderConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts.Qwen3TTSTalkerCodePredictorConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts.Qwen3TTSTalkerConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2.Qwen3TTSTokenizerV2Config][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2.Qwen3TTSTokenizerV2DecoderConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1.Qwen3TTSTokenizerV1Config][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1.Qwen3TTSTokenizerV1DecoderBigVGANConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1.Qwen3TTSTokenizerV1DecoderConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1.Qwen3TTSTokenizerV1DecoderDiTConfig][] +- [vllm_omni.model_executor.models.qwen3_tts.tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1.Qwen3TTSTokenizerV1EncoderConfig][] + +## Workers + +Worker classes and model runners for distributed inference. 
+ +- [vllm_omni.diffusion.worker.gpu_diffusion_model_runner.GPUDiffusionModelRunner][] +- [vllm_omni.diffusion.worker.gpu_diffusion_worker.GPUDiffusionWorker][] +- [vllm_omni.diffusion.worker.gpu_diffusion_worker.WorkerProc][] +- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorker][] +- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorkerProc][] +- [vllm_omni.worker.gpu_ar_model_runner.ExecuteModelState][] +- [vllm_omni.worker.gpu_ar_model_runner.GPUARModelRunner][] +- [vllm_omni.worker.gpu_ar_worker.GPUARWorker][] +- [vllm_omni.worker.gpu_generation_model_runner.GPUGenerationModelRunner][] +- [vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker][] +- [vllm_omni.worker.gpu_model_runner.OmniGPUModelRunner][] +- [vllm_omni.worker.npu.npu_ar_model_runner.ExecuteModelState][] +- [vllm_omni.worker.npu.npu_ar_model_runner.NPUARModelRunner][] +- [vllm_omni.worker.npu.npu_ar_worker.NPUARWorker][] +- [vllm_omni.worker.npu.npu_generation_model_runner.NPUGenerationModelRunner][] +- [vllm_omni.worker.npu.npu_generation_worker.NPUGenerationWorker][] +- [vllm_omni.worker.npu.npu_model_runner.OmniNPUModelRunner][] diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab4cdd5c1d3de10ac6fd752067cdd64571092c68 Binary files /dev/null and b/docs/assets/WeChat.jpg differ diff --git a/docs/cli/README.md b/docs/cli/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1fcfdb14eaca50402432a652c3f3c8b1eed9204b --- /dev/null +++ b/docs/cli/README.md @@ -0,0 +1,42 @@ +# vLLM-Omni CLI Guide + +The CLI for vLLM-Omni inherits from vllm with some additional arguments. + +## serve + +Starts the vLLM-Omni OpenAI Compatible API server. + +Start with a model: + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni +``` + +Specify the port: + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --stage-configs-path /path/to/stage_configs_file +``` + + +## bench + +Run benchmark tests for online serving throughput. +Available Commands: + +```bash +vllm bench serve --omni \ + --model Qwen/Qwen2.5-Omni-7B \ + --host server-host \ + --port server-port \ + --random-input-len 32 \ + --random-output-len 4 \ + --num-prompts 5 +``` + +See [vllm bench serve](./bench/serve.md) for the full reference of all available arguments. diff --git a/docs/cli/bench/serve.md b/docs/cli/bench/serve.md new file mode 100644 index 0000000000000000000000000000000000000000..cc47bfc3cb9a37fc4b77e489a8c5a6bd91b50514 --- /dev/null +++ b/docs/cli/bench/serve.md @@ -0,0 +1,359 @@ +# vLLM-Omni Benchmark CLI Guide +The vllm bench command launches the vLLM-Omni benchmark to evaluate the performance of multimodal models. + +## Notes +We currently only support using the "openai-chat-omni" backend. + +## Basic Parameter Description +You can use `vllm bench serve --omni --help=all` to get descriptions of all parameters. The commonly used parameters are described below: +- `--omni` + Enable Omni (multimodal) mode, supporting multimodal inputs and outputs such as images, videos, and audio. + +- `--backend` + Specify the backend adapter as openai-chat-omni, using OpenAI Chat compatible API behavior as the protocol. Currently only openai-chat-omni is supported. + +- `--model` + The model identifier to load, filled according to the models supported by vLLM-Omni. 
+
+- `--endpoint`
+  The API endpoint exposed externally, to which clients send their requests.
+
+- `--dataset-name`
+  The name of the dataset used; random-mm indicates generating random multimodal inputs (images, videos, audio).
+
+- `--num-prompts`
+  The total number of requests to send, an integer.
+
+- `--max-concurrency`
+  Maximum number of concurrent requests. This can be used to help simulate an environment where a higher-level component is enforcing a maximum number of concurrent requests. While the `--request-rate` argument controls the rate at which requests are initiated, this argument controls how many are actually allowed to execute at a time. When used in combination, the actual request rate may therefore be lower than specified with `--request-rate` if the server is not processing requests fast enough to keep up.
+
+- `--request-rate`
+  Number of requests per second. If this is inf, all requests are sent at time 0. Otherwise, a Poisson process or gamma distribution is used to synthesize the request arrival times.
+
+- `--ignore-eos`
+  Set the ignore_eos flag when sending the benchmark request.
+
+- `--metric-percentiles`
+  Comma-separated list of percentiles for selected metrics. To report the 25th, 50th, and 75th percentiles, use "25,50,75". The default value is "99". Use `--percentile-metrics` to select metrics.
+
+- `--percentile-metrics`
+  Comma-separated list of metrics for which percentiles are reported. Allowed metric names are "ttft", "tpot", "itl", "e2el", "audio_ttfp", and "audio_rtf".
+
+- `--save-result`
+  Save benchmark results to a JSON file.
+
+- `--save-detailed`
+  When saving the results, whether to include per-request information such as response, error, ttfs, tpots, etc.
+
+- `--result-dir`
+  Directory in which to save benchmark JSON results. If not specified, results are saved in the current directory.
+
+- `--result-filename`
+  Filename for the benchmark JSON results. If not specified, results are saved as `{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json`.
+
+- `--random-prefix-len`
+  Number of fixed prefix tokens before the random context in a request. The total input length is the sum of random-prefix-len and a random context length sampled from [input_len * (1 - range_ratio), input_len * (1 + range_ratio)]. Only the random and random-mm modes support this parameter.
+
+- `--random-input-len`
+  Number of input tokens per request. Only the random and random-mm modes support this parameter.
+
+- `--random-output-len`
+  Number of output tokens per request. Only the random and random-mm modes support this parameter.
+
+- `--random-range-ratio`
+  Range ratio for sampling input/output length, used only for random sampling. Must be in the range [0, 1) to define a symmetric sampling range [length * (1 - range_ratio), length * (1 + range_ratio)]. Only the random and random-mm modes support this parameter.
+
+- `--random-mm-base-items-per-request`
+  Base number of multimodal items per request for random-mm. The actual per-request count is sampled around this base using `--random-mm-num-mm-items-range-ratio`. Only the random-mm mode supports this parameter.
+
+- `--random-mm-limit-mm-per-prompt`
+  Per-modality hard caps for items attached per request, e.g. '{"image": 3, "video": 1, "audio": 1}'.
The sampled per-request item + count is clamped to the sum of these limits. When a modality + reaches its cap, its buckets are excluded and probabilities are + renormalized. + Only the random-mm mode supports this parameter. + +- `--random-mm-num-mm-items-range-ratio` + Range ratio r in [0, 1] for sampling items per request. + We sample uniformly from the closed integer range + [floor(n*(1-r)), ceil(n*(1+r))] + where n is the base items per request. + r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped + to the sum of per-modality limits from + --random-mm-limit-mm-per-prompt. + An error is raised if the computed min exceeds the max. + Only the random-mm mode supports this parameter. + +- `--random-mm-bucket-config` + The bucket config is a dictionary mapping a multimodal item + sampling configuration to a probability. + Currently allows for 3 modalities: audio, images and videos. + A bucket key is a tuple of (height, width, num_frames) + The value is the probability of sampling that specific item. + Example: + --random-mm-bucket-config + "{(256, 256, 1): 0.5, (720, 1280, 16): 0.4, (0, 1, 5): 0.10}" + First item: images with resolution 256x256 w.p. 0.5 + Second item: videos with resolution 720x1280 and 16 frames + Third item: audios with 1s duration and 5 channels w.p. 0.1 + OBS.: If the probabilities do not sum to 1, they are normalized. + Only the random-mm mode supports this parameter + +## Usage Examples + +### Online Benchmark +
+Show more + +First start serving your model: + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni +``` + +Then run the benchmarking for sharegpt: + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --omni \ + --port 43845 \ + --model /home/models/Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --endpoint /v1/chat/completions \ + --backend openai-chat-omni \ + --num-prompts 2 \ + --dataset-name sharegpt \ + --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ + --percentile-metrics ttft,tpot,itl,e2el +``` +If successful, you will see the following output: +```text +============ Serving Benchmark Result ============ +Successful requests: 2 +Failed requests: 0 +Benchmark duration (s): 81.63 +Request throughput (req/s): 0.02 +Peak concurrent requests: 2.00 +----------------End-to-end Latency---------------- +Mean E2EL (ms): 56966.13 +Median E2EL (ms): 56966.13 +P99 E2EL (ms): 81016.80 +================== Text Result =================== +Total input tokens: 36 +Total generated tokens: 5926 +Output token throughput (tok/s): 72.60 +Peak output token throughput (tok/s): 103.00 +Peak concurrent requests: 2.00 +Total Token throughput (tok/s): 73.04 +---------------Time to First Token---------------- +Mean TTFT (ms): 124.76 +Median TTFT (ms): 124.76 +P99 TTFT (ms): 156.10 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 481.30 +Median TPOT (ms): 481.30 +P99 TPOT (ms): 947.55 +---------------Inter-token Latency---------------- +Mean ITL (ms): 25.11 +Median ITL (ms): 0.33 +P99 ITL (ms): 25.17 +================== Audio Result ================== +Total audio duration generated(s): 3.95 +Total audio frames generated: 94890 +Audio throughput(audio duration/s): 0.05 +================================================== +``` + +Or run the benchmarking for random: + +```bash +vllm bench serve \ + --omni \ + --port 43845 \ + --endpoint /v1/chat/completions \ + --backend openai-chat-omni \ + --model /home/models/Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --dataset-name random \ + --num-prompts 2 \ + --random-prefix-len 5 \ + --random-input-len 10 \ + --random-output-len 100 \ + --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \ + --ignore-eos +``` + +If successful, you will see the following output: + +```text +============ Serving Benchmark Result ============ +Successful requests: 2 +Failed requests: 0 +Benchmark duration (s): 24.35 +Request throughput (req/s): 0.08 +Peak concurrent requests: 2.00 +----------------End-to-end Latency---------------- +Mean E2EL (ms): 22576.23 +Median E2EL (ms): 22576.23 +P99 E2EL (ms): 24205.72 +================== Text Result =================== +Total input tokens: 30 +Total generated tokens: 8973 +Output token throughput (tok/s): 368.52 +Peak output token throughput (tok/s): 81.00 +Peak concurrent requests: 2.00 +Total Token throughput (tok/s): 369.76 +---------------Time to First Token---------------- +Mean TTFT (ms): 125.16 +Median TTFT (ms): 125.16 +P99 TTFT (ms): 155.88 +-----Time per Output Token (excl. 
1st token)------ +Mean TPOT (ms): 5.01 +Median TPOT (ms): 5.01 +P99 TPOT (ms): 5.42 +---------------Inter-token Latency---------------- +Mean ITL (ms): 34.15 +Median ITL (ms): 0.01 +P99 ITL (ms): 376.19 +================== Audio Result ================== +Total audio duration generated(s): 3.95 +Total audio frames generated: 94890 +Audio throughput(audio duration/s): 0.16 +---------------Time to First Packet--------------- +Mean AUDIO_TTFP (ms): 11756.89 +Median AUDIO_TTFP (ms): 11756.89 +P99 AUDIO_TTFP (ms): 20854.25 +-----------------Real Time Factor----------------- +Mean AUDIO_RTF: 3.75 +Median AUDIO_RTF: 3.75 +P99 AUDIO_RTF: 7.39 +================================================== +``` +Notes: +We use (audio generation time - first packet latency) / audio duration to calculate RTF. + +
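+As a quick illustration of that formula (this helper is not part of the benchmark tool; the function name and inputs are made up for the example):
+
+```python
+# Illustrative only: mirrors the RTF definition described in the note above.
+def audio_rtf(generation_time_s: float, first_packet_latency_s: float, audio_duration_s: float) -> float:
+    """RTF = (audio generation time - first packet latency) / audio duration."""
+    return (generation_time_s - first_packet_latency_s) / audio_duration_s
+
+# Example: 10 s spent generating, first packet after 2 s, 4 s of audio produced.
+print(audio_rtf(10.0, 2.0, 4.0))  # 2.0 -> under this definition, audio is produced slower than it plays back
+```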
+ +### Multi-Modal Benchmark + +
+Benchmark the performance of multi-modal requests in vLLM-Omni.
+
+Generate synthetic image, video, and audio inputs alongside random text prompts to stress-test multi-modal models without external datasets.
+
+Notes:
+
+- Works only with the online benchmark via the OpenAI backend (`--backend openai-chat-omni`) and endpoint `/v1/chat/completions`.
+
+Start the server (example):
+
+```bash
+vllm serve Qwen/Qwen2.5-Omni-7B --omni
+```
+
+It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via `--random-output-len`.
+
+Then run the benchmarking script:
+```bash
+vllm bench serve \
+    --omni \
+    --dataset-name random-mm \
+    --port 40849 \
+    --model /home/models/Qwen/Qwen3-Omni-30B-A3B-Instruct \
+    --endpoint /v1/chat/completions \
+    --backend openai-chat-omni \
+    --request-rate 1 \
+    --num-prompts 1 \
+    --random-input-len 10 \
+    --random-range-ratio 0.0 \
+    --random-mm-base-items-per-request 2 \
+    --random-mm-num-mm-items-range-ratio 0 \
+    --random-mm-limit-mm-per-prompt '{"image":1,"video":1, "audio": 1}' \
+    --random-mm-bucket-config '{"(32, 32, 1)": 0.5, "(0, 1, 1)": 0.1, "(32, 32, 2)":0.4}' \
+    --ignore-eos \
+    --percentile-metrics ttft,tpot,itl \
+    --random-output-len 2 \
+    --extra_body '{"modalities": ["text"]}'
+```
+
+If successful, you will see the following output:
+
+```text
+============ Serving Benchmark Result ============
+Successful requests: 1
+Failed requests: 0
+Request rate configured (RPS): 1.00
+Benchmark duration (s): 1.21
+Request throughput (req/s): 0.83
+Peak concurrent requests: 1.00
+================== Text Result ===================
+Total input tokens: 10
+Total generated tokens: 3
+Output token throughput (tok/s): 2.49
+Peak output token throughput (tok/s): 3.00
+Peak concurrent requests: 1.00
+Total Token throughput (tok/s): 10.77
+---------------Time to First Token----------------
+Mean TTFT (ms): 179.74
+Median TTFT (ms): 179.74
+P99 TTFT (ms): 179.74
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 12.76
+Median TPOT (ms): 12.76
+P99 TPOT (ms): 12.76
+---------------Inter-token Latency----------------
+Mean ITL (ms): 12.76
+Median ITL (ms): 12.76
+P99 ITL (ms): 25.24
+================== Audio Result ==================
+Total audio duration generated(s): 0.00
+Total audio frames generated: 0
+Audio throughput(audio duration/s): 0.00
+==================================================
+```
+
+Behavioral notes:
+
+- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
+
+How sampling works (a minimal sketch follows this list):
+
+- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
+- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
+- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
+This should be treated as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large value, although doing so might cause errors due to the engine config `--limit-mm-per-prompt`. 
+- The resulting request keeps the prompt as plain text; when `random-mm` is used with the OpenAI Chat backend, the synthetic multi-modal items are attached via `multi_modal_data` (OpenAI Chat format).
+
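+The sketch below illustrates the sampling procedure described above. It is a simplified illustration only: the bucket-to-modality mapping and helper names are assumptions, not the benchmark tool's actual internals.
+
+```python
+import random
+
+# Illustrative reproduction of the random-mm bucket sampling described above.
+bucket_probs = {(32, 32, 1): 0.5, (0, 1, 1): 0.1, (32, 32, 2): 0.4}  # (H, W, T) -> probability
+limits = {"image": 1, "video": 1, "audio": 1}                        # per-prompt modality limits
+
+def modality_of(bucket):
+    h, w, t = bucket
+    if h == 0:                               # assumption: zero-height buckets represent audio
+        return "audio"
+    return "image" if t == 1 else "video"    # assumption: T == 1 -> image, T > 1 -> video
+
+def sample_items(k):
+    counts = {m: 0 for m in limits}
+    items = []
+    for _ in range(k):
+        # drop buckets whose modality already hit its limit, then renormalize the rest
+        available = {b: p for b, p in bucket_probs.items()
+                     if counts[modality_of(b)] < limits[modality_of(b)]}
+        if not available:
+            break
+        total = sum(available.values())
+        bucket = random.choices(list(available), [p / total for p in available.values()])[0]
+        counts[modality_of(bucket)] += 1
+        items.append(bucket)
+    return items
+
+print(sample_items(k=2))  # e.g. [(32, 32, 1), (32, 32, 2)]
+```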
diff --git a/docs/community/contact_us.md b/docs/community/contact_us.md new file mode 100644 index 0000000000000000000000000000000000000000..09c7815a038834a9a09ee13d947a2e0e25bb79c6 --- /dev/null +++ b/docs/community/contact_us.md @@ -0,0 +1,5 @@
+# Contact Us
+
+- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm-omni/issues)
+- For coordinating contributions and development, and for discussions with other users and developers, please join the `sig-omni` channel in our [Slack](https://slack.vllm.ai/) or use the [vLLM Forum](https://discuss.vllm.ai/)
+- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm-omni/security/advisories) feature
diff --git a/docs/community/meetups.md b/docs/community/meetups.md new file mode 100644 index 0000000000000000000000000000000000000000..3374fe711cfdf2d4c1d6257f9857b0d25e4d47c4 --- /dev/null +++ b/docs/community/meetups.md @@ -0,0 +1 @@
+# Meetups
diff --git a/docs/community/volunteers.md b/docs/community/volunteers.md new file mode 100644 index 0000000000000000000000000000000000000000..2c25485ea9041ca0944c01a09721266292f40621 --- /dev/null +++ b/docs/community/volunteers.md @@ -0,0 +1,12 @@
+# Volunteers for Bugfix and CI
+
+We encourage you to check the current docs and [issues](https://github.com/vllm-project/vllm-omni/issues) to find possible solutions to your questions. If none of these solves your problem, please open an issue describing the bug or CI problem you are hitting during development.
+
+If you urgently need help locating and fixing a bug or CI problem, please contact the community volunteers listed below.
+
+| Dec 4-Dec 12 | Dec 15-Dec 19 | Dec 22-Dec 26 | Dec 29-Jan 2, 2026 | Jan 5-Jan 9 | Jan 12-Jan 16 |
+|----------|----------|----------|----------|----------|----------|
+| Conw729 | yinpeiqi | tzhouam | SamitHuang | gcanlin | natureofnature |
+| david6666666 | R2-Y | hsliuustc0106 | Gaohan123 | ZJY0516 | qibaoyuan |
+
+We kindly welcome more contributors to fix bugs and contribute new features!
diff --git a/docs/configuration/README.md b/docs/configuration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..02440b95dcef834ef9c222f92c88aaea31b0ad7e --- /dev/null +++ b/docs/configuration/README.md @@ -0,0 +1,21 @@
+# Configuration Options
+
+This section lists the most common options for running vLLM-Omni.
+
+For options within a vLLM engine, please refer to the [vLLM Configuration](https://docs.vllm.ai/en/v0.14.0/configuration/index.html) documentation.
+
+Currently, the main options are maintained in per-model stage configs. 
+
+For a specific example, please refer to the [Qwen2.5-Omni stage config](stage_configs/qwen2_5_omni.yaml).
+
+For an introduction, please check the [introduction to stage configs](./stage_configs.md).
+
+## Memory Configuration
+
+- **[GPU Memory Calculation and Configuration](./gpu_memory_utilization.md)** - Guide on how to calculate memory requirements and set up `gpu_memory_utilization` for optimal performance
+
+## Optimization Features
+
+- **[TeaCache Configuration](../user_guide/diffusion/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss
+- **[Cache-DiT Configuration](../user_guide/diffusion/cache_dit_acceleration.md)** - Enable Cache-DiT as a cache acceleration backend for DiT models
+- **[Parallelism Configuration](../user_guide/diffusion/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for DiT models
diff --git a/docs/configuration/gpu_memory_utilization.md b/docs/configuration/gpu_memory_utilization.md new file mode 100644 index 0000000000000000000000000000000000000000..19fc042aa52e4d1d105376c438f20694113cb7b5 --- /dev/null +++ b/docs/configuration/gpu_memory_utilization.md @@ -0,0 +1,207 @@
+# GPU Memory Calculation and Configuration
+
+This guide explains how to calculate GPU memory requirements and properly configure `gpu_memory_utilization` for vLLM-Omni stages.
+
+## Overview
+
+`gpu_memory_utilization` is a critical parameter that controls how much GPU memory each stage can use. It's specified as a fraction between 0.0 and 1.0, where:
+- `0.8` means 80% of the GPU's total memory
+- `1.0` means 100% of the GPU's total memory (not recommended, leaves no buffer)
+
+## How Memory is Calculated
+
+### Memory Allocation Formula
+
+For each stage, vLLM-Omni calculates the requested memory as:
+
+```
+requested_memory = total_gpu_memory × gpu_memory_utilization
+```
+
+The system checks that:
+```
+free_memory ≥ requested_memory
+```
+
+If this condition is not met, the stage will fail to initialize with an error message showing the memory requirements.
+
+### Memory Components
+
+The total memory used by a stage includes:
+
+1. **Model Weights**: The size of the model parameters loaded on the GPU
+2. **KV Cache**: Memory for storing the key-value cache during generation
+3. **Activation Memory**: Temporary memory for intermediate computations
+4. **System Overhead**: Memory used by CUDA, PyTorch, and other system components
+5. 
**Non-Torch Memory**: Memory allocated outside of PyTorch (e.g., CUDA graphs) + +### Example Calculation + +For a GPU with 80GB total memory: +- `gpu_memory_utilization: 0.8` → 64GB available for the stage +- `gpu_memory_utilization: 0.6` → 48GB available for the stage +- `gpu_memory_utilization: 0.15` → 12GB available for the stage + +## Setting Up `gpu_memory_utilization` + +### Step 1: Determine GPU Memory + +First, check your GPU's total memory: + +```bash +# Using nvidia-smi +nvidia-smi --query-gpu=memory.total --format=csv + +# Or using Python +python -c "import torch; print(f'{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')" +``` + +### Step 2: Estimate Model Memory Requirements + +#### For Autoregressive (AR) Stages + +AR stages typically need more memory due to: +- Large model weights +- KV cache for attention +- Activation buffers + +#### For Diffusion/Generation Stages + +Diffusion stages (like code2wav) typically need less memory: +- Smaller model components +- Different memory access patterns + +**Typical values:** +- `0.1 - 0.3` for most diffusion stages + +### Step 3: Consider Multi-Stage Scenarios + +When multiple stages share the same GPU, you must ensure the sum of their `gpu_memory_utilization` values doesn't exceed 1.0. + +**Example: Two stages on GPU 0** +```yaml +stage_args: + - stage_id: 0 + runtime: + devices: "0" + engine_args: + gpu_memory_utilization: 0.6 # Uses 60% of GPU 0 + + - stage_id: 1 + runtime: + devices: "0" + engine_args: + gpu_memory_utilization: 0.3 # Uses 30% of GPU 0 + # Total: 90% of GPU 0 (safe, leaves 10% buffer) +``` + +**Important:** If stages run on different GPUs, each can use up to 1.0 independently. + +### Step 4: Account for Tensor Parallelism + +When using `tensor_parallel_size > 1`, the model is split across multiple GPUs, so each GPU needs less memory. + +**Example: 2-way tensor parallelism** +```yaml +stage_args: + - stage_id: 0 + runtime: + devices: "0,1" # Uses both GPUs + engine_args: + tensor_parallel_size: 2 + gpu_memory_utilization: 0.6 # 60% per GPU + # Model is split, so each GPU uses ~30% of model memory +``` + +## Examples + +### Qwen3-Omni-MoE on 2x H100-80GB + +```yaml +stage_args: + - stage_id: 0 # Thinker stage with TP=2 + runtime: + devices: "0,1" + engine_args: + tensor_parallel_size: 2 + gpu_memory_utilization: 0.6 # 48GB per GPU + + - stage_id: 1 # Talker stage + runtime: + devices: "1" + engine_args: + gpu_memory_utilization: 0.3 # 24GB on GPU 1 + + - stage_id: 2 # Code2Wav stage + runtime: + devices: "0" + engine_args: + gpu_memory_utilization: 0.1 # 8GB on GPU 0 +``` +**Note:** In this configuration, stages 0 and 2 share GPU 0, but they run at different times in the pipeline, so their memory usage doesn't overlap. + +## Troubleshooting + +### Error: "Free memory is less than desired GPU memory utilization" + +This means the GPU doesn't have enough free memory when the stage starts. + +**Solutions:** +1. Free up memory by closing other processes +2. Reduce `gpu_memory_utilization` for this stage +3. Use a GPU with more memory +4. Move the stage to a different GPU + +### Error: OOM during inference + +The stage initialized but ran out of memory during processing. + +**Solutions:** +1. Reduce `max_num_batched_tokens` +2. Reduce `max_batch_size` in runtime config +3. Lower `gpu_memory_utilization` slightly +4. Enable quantization if supported + +### Memory Not Fully Utilized + +If you see low memory usage, you can: +1. Increase `gpu_memory_utilization` to allow larger KV cache +2. 
Increase `max_num_batched_tokens` for better batching
+3. Check if other stages are limiting throughput
+
+## Useful Formulas for Memory Calculation
+
+### KV Cache Memory
+
+The KV cache size depends on:
+- Number of sequences in batch
+- Sequence length (prompt + generation)
+- Model hidden size
+- Number of attention heads
+- Number of layers
+
+Approximate formula:
+```
+kv_cache_memory ≈ batch_size × seq_len × hidden_size × num_layers × 2 × dtype_size
+```
+The factor of 2 accounts for the keys and values.
+
+### Model Weight Memory
+
+```
+model_memory ≈ num_parameters × dtype_size
+```
+
+For example:
+- 7B parameters in FP16: ~14GB
+- 7B parameters in FP32: ~28GB
+- 7B parameters in INT8: ~7GB
+
+### Activation Memory
+
+Activation memory is typically smaller but varies with:
+- Batch size
+- Sequence length
+- Model architecture
+
+It's usually 10-30% of model weight memory during inference.
diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md new file mode 100644 index 0000000000000000000000000000000000000000..1f1a1890251b8f242da282c9c703823fc24e900b --- /dev/null +++ b/docs/configuration/stage_configs.md @@ -0,0 +1,275 @@
+# Stage configs for vLLM-Omni
+
+In vLLM-Omni, the target model is separated into multiple stages, which are processed by different LLMEngines, DiffusionEngines or other types of engines. Depending on the type of stage, such as an Autoregressive (AR) stage or a Diffusion transformer (DiT) stage, each stage can choose the corresponding scheduler and model worker to load with its engine in a plug-in fashion.
+
+!!! note
+    Default stage config YAMLs (for example, `vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml` and `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`) are bundled and loaded automatically when `stage_configs_path` is not provided. They have been verified to work on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni.
+
+Therefore, as a core part of vLLM-Omni, the stage configs for a model have several main functions:
+
+- Claim the partition of stages and their corresponding class implementations in `model_executor/models`.
+- The disaggregated configuration for each stage and the communication topology among them.
+- Engine arguments for each engine within the stage.
+- Input and output dependencies for each stage.
+- Default input parameters.
+
+If you want to modify any part of it, a custom stage-config file can be passed as an argument both online and offline, as in the examples below:
+
+For offline inference (assuming the necessary dependencies have been imported):
+```python
+model_name = "Qwen/Qwen2.5-Omni-7B"
+omni_llm = OmniLLM(model=model_name, stage_configs_path="/path/to/custom_stage_configs.yaml")
+```
+
+For online serving:
+```bash
+vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+```
+!!! important
+    We are actively iterating on the definition of stage configs, and we welcome all feedback from both community users and developers to help us shape the development!
+
+Below is a specific example of `stage_configs.yaml` for Qwen2.5-Omni.
+```yaml
+# stage config for running qwen2.5-omni with architecture of OmniLLM. 
+stage_args: + - stage_id: 0 # mark the unique id for each stage + runtime: # The disaggregated configuration + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 # the batch_size for offline inference + engine_args: # Engine arguments for a certain engine + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration # The model implementation registered in model_executor/models/registry.py + worker_type: ar # The specific worker used + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler # The specific scehduler used + gpu_memory_utilization: 0.8 # The gpu memory allocation for the stage within a single chip + enforce_eager: true # Now we only support eager mode + trust_remote_code: true # Needed by huggingface config parsing + engine_output_type: latent # It claims that the stage will input latent hiddenstates besides token ids + enable_prefix_caching: false # For request with hiddenstates output, the prefix caching is not supported now + is_comprehension: true # If the stage is a text or multimodal comprehension module. If it is, the AsyncOmni will use its tokenizer as default + final_output: true # If the stage has output as part of final outputs. If it is false, which means that the stage only works as a intermediate role. + final_output_type: text # What is the final output type. It can be text and audio now. + default_sampling_params: # sampling parameters for the stage. Their meaning aligns with vLLM. + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 3 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full 
input (-1) + to: 2 + window_size: -1 + +``` + +## Stage Configuration Arguments + +Each stage in the `stage_args` list contains the following configuration options: + +### `stage_id` + +A unique identifier for each stage in the multi-stage pipeline. Stages are numbered sequentially starting from 0, and this ID is used to reference stages in inter-stage dependencies (e.g., `engine_input_source`). + +### `runtime` + +Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed. + +#### `runtime.process` + +Whether to run this stage in a separate process. When set to `true`, the stage will be executed in an isolated process, enabling better resource isolation and parallel execution across different stages. This is essential for multi-GPU deployments where different stages run on different devices. + +Default: `true` + +#### `runtime.devices` + +Visible devices for this stage, specified as a string. This controls which GPU devices are available to the stage process, similar to setting `CUDA_VISIBLE_DEVICES` or using `torch.cuda.set_device()`. For example, `"0"` uses GPU 0, `"1"` uses GPU 1, and `"0,1"` makes both GPUs 0 and 1 visible. + +Default: `"0"` + +#### `runtime.max_batch_size` + +The maximum batch size for offline inference in this stage. This limits how many sequences can be processed together in a single batch during offline inference operations. + +Default: `1` + +### `engine_args` + +Engine arguments for configuring the LLM engine, diffusion engine, or other engine types used by this stage. + +#### `engine_args.model_stage` + +The name identifier for this model stage within the multi-stage architecture. This is used internally to distinguish different stages of the same model (e.g., "thinker", "talker", "code2wav" in Qwen2.5-Omni). + +#### `engine_args.model_arch` + +The model architecture class name that is registered in `model_executor/models/registry.py`. This specifies which model implementation to use for this stage. The class must be registered in the model registry for vLLM-Omni to locate and instantiate it. + +#### `engine_args.worker_cls` + +The specific worker class to use for this stage. This determines how the model computations are executed. Examples include `vllm_omni.worker.gpu_ar_worker.GPUARWorker` for autoregressive stages and `vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker` for diffusion-based stages. + +#### `engine_args.scheduler_cls` + +The scheduler class to use for this stage. The scheduler manages request queuing, batching, and execution order. Examples include `vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler` for standard stages and `vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler` for diffusion stages. + +#### `engine_args.gpu_memory_utilization` + +The fraction of GPU memory to allocate for this stage within a single GPU chip. This is a value between 0.0 and 1.0, where 0.8 means 80% of the GPU memory will be used by this stage. This allows fine-grained control over memory allocation when multiple stages share the same GPU or when reserving memory for other operations. + +Default: `0.8` + +!!! tip "Memory Configuration Guide" + For detailed information on how to calculate memory requirements and properly configure `gpu_memory_utilization`, see the [GPU Memory Calculation and Configuration Guide](./gpu_memory_utilization.md). + +#### `engine_args.enforce_eager` + +Whether to enforce eager execution mode. 
When set to `true`, the engine will run in eager mode without using CUDA graphs or other compilation optimizations. Currently, vLLM-Omni only supports eager mode. + +Default: `true` + +#### `engine_args.trust_remote_code` + +Whether to trust remote code when loading models from Hugging Face. This is required for models that use custom code in their configuration files. Set to `true` when loading models that require custom model implementations. + +Default: `true` + +#### `engine_args.engine_output_type` + +Specifies the type of output produced by this stage's engine. This determines what kind of data flows to downstream stages. Possible values include `latent` (hidden states), `text` (tokenized text), and `audio` (audio waveforms). When set to `latent`, the stage outputs latent hidden states in addition to token IDs, which are consumed by downstream stages. + +Default: `latent` + +#### `engine_args.enable_prefix_caching` + +Whether to enable prefix caching for this stage. Prefix caching can improve performance by caching KV cache for common prompt prefixes. However, for requests that output hidden states (when `engine_output_type` is `latent`), prefix caching is not currently supported and should be set to `false`. + +Default: `false` + +### `is_comprehension` + +Whether this stage is a text or multimodal comprehension module. When set to `true`, the stage acts as a comprehension module that processes input text or multimodal content. If this is the first comprehension stage, `AsyncOmni` will use its tokenizer as the default tokenizer for the entire pipeline. + +Default: `true` + +### `final_output` + +Whether this stage produces output that is part of the final outputs returned to the user. When set to `false`, the stage only works as an intermediate stage, processing data that flows to downstream stages but not contributing directly to the final response. + +Default: `true` + +### `final_output_type` + +The type of final output produced by this stage. This specifies what format the output will be in when returned to the user. Currently supported values are `text` (for text generation) and `audio` (for audio generation). + +Default: `text` + +### `default_sampling_params` + +Default sampling parameters for this stage. These parameters control the generation behavior and align with vLLM's sampling parameter semantics. These defaults are used when no explicit sampling parameters are provided in the request. + +#### `default_sampling_params.temperature` + +Sampling temperature for controlling randomness. Lower values (e.g., 0.0) make the output more deterministic and focused, while higher values increase randomness. + +Default: `0.0` + +#### `default_sampling_params.top_p` + +Nucleus sampling parameter. Only tokens with cumulative probability mass up to `top_p` are considered. This helps filter out low-probability tokens while maintaining diversity. + +Default: `1.0` + +#### `default_sampling_params.top_k` + +Top-k sampling parameter. Only the top `k` most likely tokens are considered. Set to `-1` to disable top-k filtering and consider all tokens. + +Default: `-1` + +#### `default_sampling_params.max_tokens` + +Maximum number of tokens to generate in this stage. This limits the length of the output sequence. + +Default: `2048` + +#### `default_sampling_params.seed` + +Random seed for reproducible generation. When set, the random number generator will be initialized with this seed to ensure consistent outputs across runs. 
+ +Default: `42` + +#### `default_sampling_params.detokenize` + +Whether to detokenize the output tokens into text. When set to `true`, token IDs are converted back to readable text strings. + +Default: `True` + +#### `default_sampling_params.repetition_penalty` + +Penalty applied to tokens that have already appeared in the generated sequence. Values greater than 1.0 discourage repetition, while values less than 1.0 encourage it. A value of 1.0 applies no penalty. + +Default: `1.1` diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e20e79c3e91e23692691fc259f872c41e7d2ba96 --- /dev/null +++ b/docs/configuration/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,94 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. +stage_args: + - stage_id: 0 + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git 
a/docs/contributing/DOCS_GUIDE.md b/docs/contributing/DOCS_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..100bac67423a10d72f9a7e2b6a1733f5dedc684c --- /dev/null +++ b/docs/contributing/DOCS_GUIDE.md @@ -0,0 +1,139 @@ +# Documentation Build Guide + +This directory contains the source files for the vLLM-Omni documentation. + +## Building Documentation Locally + +### Prerequisites + +Install documentation dependencies: + +```bash +uv pip install -e ".[docs]" +``` + +### Build and Serve Documentation + +From the project root: + +```bash +# Serve documentation locally (auto-reload on changes) +# This starts a local web server at http://127.0.0.1:8000 +mkdocs serve + +# Build static site (generates HTML files in site/ directory) +mkdocs build +``` + +When using `mkdocs serve`, the documentation will be automatically available at `http://127.0.0.1:8000`. The server will automatically reload when you make changes to the documentation files. + +## Auto-generating API Documentation + +The documentation automatically extracts docstrings from the code using mkdocstrings. To ensure your code is documented: + +1. Add docstrings to all public classes, functions, and methods +2. Use Google or NumPy style docstrings (both are supported) +3. Rebuild the documentation to see changes + +Example docstring: + +```python +class Omni: + """Main entry point for vLLM-Omni inference. + + This class provides a high-level interface for running multi-modal + inference with non-autoregressive models. + + Args: + model: Model name or path + stage_configs: Optional stage configurations + **kwargs: Additional arguments passed to the engine + + Example: + >>> llm = Omni(model="Qwen/Qwen2.5-Omni") + >>> outputs = llm.generate(prompts="Hello") + """ +``` + +## Documentation Structure + +``` +docs/ +├── index.md # Main documentation page +├── getting_started/ # Getting started guides +├── architecture/ # Architecture documentation +├── api/ # API reference (auto-generated from code) +├── examples/ # Code examples +└── stylesheets/ # Custom CSS +``` + +## Publishing Documentation + +### GitHub Pages (Recommended) + +The documentation is automatically deployed to GitHub Pages using GitHub Actions. + +1. **Enable GitHub Pages**: + - Go to repository `Settings` → `Pages` + - Set `Source` to `GitHub Actions` + - Save settings + +2. **Push changes**: + ```bash + git push origin main + ``` + +3. **Documentation will be available at**: + - `https://vllm-omni.readthedocs.io` + +The GitHub Actions workflow (`.github/workflows/docs.yml`) will automatically: +- Build the documentation when you push to `main` branch +- Deploy it to GitHub Pages +- Update the documentation whenever you make changes + + +### Read the Docs (Alternative) + +You can also use Read the Docs for hosting: + +1. Sign up at https://readthedocs.org/ +2. Import the `vllm-project/vllm-omni` repository +3. Read the Docs will automatically build using `.readthedocs.yml` +4. Documentation will be available at: `https://vllm-omni.readthedocs.io/` + +## Configuration + +The documentation configuration is in `mkdocs.yml` at the project root. 
+ +## Tips + +- **API Documentation**: API docs are automatically generated using `mkdocs-api-autonav` and `mkdocstrings` + - No need to manually create API pages - they're generated automatically + - Use `[module.name.ClassName][]` syntax for cross-references in Summary pages +- **Code Snippets**: Use `--8<-- "path/to/file.py"` for including code snippets +- **Markdown**: Use Markdown for all documentation (no need for RST) +- **Material Theme**: Use Material theme features like: + - Admonitions: `!!! note`, `!!! warning`, etc. + - Code blocks with syntax highlighting + - Tabs for organizing content + - Math formulas using `pymdownx.arithmatex` + +## Troubleshooting + +### Documentation not updating + +- Make sure you've saved all files +- If using `mkdocs serve`, it should auto-reload +- Check for syntax errors in `mkdocs.yml` + +### API links not working + +- Ensure class names match exactly (case-sensitive) +- Check that the module is imported correctly +- Run `mkdocs build --strict` to check for errors + +### Build errors + +- Check Python version (requires 3.9+) +- Ensure all dependencies are installed: `pip install -e ".[docs]"` +- Check `mkdocs.yml` syntax with `mkdocs build --strict` diff --git a/docs/contributing/README.md b/docs/contributing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..29a02dc416a6b1a20d65e9d9c246ed7fc8b25aaf --- /dev/null +++ b/docs/contributing/README.md @@ -0,0 +1,155 @@ +# Contributing to vLLM-Omni + +Thank you for your interest in contributing to vLLM-Omni! This document provides guidelines and instructions for contributing. + +!!! note + We host weekly developer-facing online meetings to discuss milestones and updates **every Tuesday at 19:30 PDT**. Meeting link as well as the past meeting notes can be found [here](https://tinyurl.com/vllm-omni-meeting). + +## Getting Started + +vLLM-Omni uses `uv` as the environment manager, to create and manage Python environments. Please follow the documentation to install `uv`. After installing `uv`, you can create a new Python environment using the following commands: + +```bash +uv venv --python 3.12 --seed +source .venv/bin/activate +``` + +### Development Environment for vLLM and vLLM-Omni + +vLLM-Omni is quickly evolving, please see the [installation guide](../getting_started/installation/README.md) for details. It's recommended to build from source to provide the latest development environment. + +!!! tip + vLLM-Omni is compatible with Python versions 3.10 to 3.12. However, we recommend developing with Python 3.12 to minimize the chance of your local environment clashing with our CI environment. + +### Adding a new model to vLLM-Omni + +Please check [model implementation](model/README.md) for how to add diffusion and omni-modality models to vLLM-Omni. + +### Linting + +vLLM-Omni uses `pre-commit` to lint and format the codebase. See [pre-commit documentation](https://pre-commit.com/#usage) if `pre-commit` is new to you. Setting up `pre-commit` is as easy as: + +```bash +uv pip install pre-commit +pre-commit install +``` + +vLLM-Omni's `pre-commit` hooks will now run automatically every time you commit. + +!!! tip + You can manually run the `pre-commit` hooks using: + + ```bash + pre-commit run # runs on staged files + pre-commit run --show-diff-on-failure --color=always --all-files # runs on all files (short for --all-files) + ``` + +### Documentation + +MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. 
Documentation source files are written in Markdown, and configured with a single YAML configuration file, `mkdocs.yml`. + +Get started with: + +```bash +uv pip install -e ".[docs]" +``` + +MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. From the root of the repository, run: + +```bash +mkdocs serve # with API ref (~10 minutes) +API_AUTONAV_EXCLUDE=vllm_omni mkdocs serve # API ref off (~15 seconds) +``` + +Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready! Open in your browser to see it. + +For additional features and advanced configurations, refer to the: + +- [MkDocs documentation](https://www.mkdocs.org/) +- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use) + +### Testing + +vLLM-Omni uses `pytest` to test the codebase. + +```bash +# Run all tests +pytest tests/ + +# Run tests for a single test file with detailed output +pytest -s -v tests/test_omni_llm.py +``` + +!!! warning + Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU platform to run unit tests locally, rely on the continuous integration system to run the tests for now. + +## Issues + +If you encounter a bug or have a feature request, please search existing issues first to see if it has already been reported. If not, please file a new issue, providing as much relevant information as possible. + +!!! important + If you discover a security vulnerability, please report it by creating a GitHub issue with the `security` label. + +## Pull Requests & Code Reviews + +Thank you for your contribution to vLLM-Omni! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM-Omni maintain the code quality and improve the efficiency of the review process. + +### DCO and Signed-off-by + +When contributing changes to this project, you must agree to the [DCO](https://developercertificate.org/). Commits must include a `Signed-off-by:` header which certifies agreement with the terms of the DCO. + +Using `-s` with `git commit` will automatically add this header. + +!!! tip + You can enable automatic sign-off via your IDE: + + - **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window. It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`. + - **VSCode**: Open the Settings editor and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field. + +### PR Title and Classification + +Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following: + +- `[Bugfix]` for bug fixes. +- `[CI/Build]` for build or continuous integration improvements. +- `[Doc]` for documentation fixes and improvements. +- `[Model]` for adding a new model or improving an existing model. Model name should appear in the title. +- `[Frontend]` For changes on the vLLM-Omni frontend (e.g., OpenAI API server, `OmniLLM` class, etc.) +- `[Kernel]` for changes affecting CUDA kernels or other compute kernels. +- `[Core]` for changes in the core vLLM-Omni logic (e.g., `OmniProcessor`, `OmniARScheduler`, etc.) +- `[Hardware][Vendor]` for hardware-specific changes. Vendor name should appear in the prefix, such as [Ascend] for Ascend NPUs. +- `[Misc]` for PRs that do not fit the above categories. Please use this sparingly. + +!!! 
note + If the PR spans more than one category, please include all relevant prefixes. + +### Code Quality + +The PR needs to meet the following code quality standards: + +- We adhere to Google Python style guide and Google C++ style guide. +- Pass all linter checks. +- The code needs to be well-documented to ensure future contributors can easily understand the code. +- Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests. +- Please add documentation to `docs/` if the PR modifies the user-facing behaviors of vLLM-Omni. It helps vLLM-Omni users understand and utilize the new features or changes. + +### Notes for Large Changes + +Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with `rfc-required` and might not go through the PR. + +### What to Expect for the Reviews + +The goal of the vLLM-Omni team is to be a _transparent reviewing machine_. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM-Omni team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process: + +- After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability. +- After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM-Omni team. +- After the review, the reviewer will put an `action-required` label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR. +- Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion. + +## Additional Resources + +- [Design Documents](../design/index.md) - Architecture and design documentation + +## Thank You + +Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM-Omni. All of your contributions help make vLLM-Omni a great tool and community for everyone! diff --git a/docs/contributing/ci/failures.md b/docs/contributing/ci/failures.md new file mode 100644 index 0000000000000000000000000000000000000000..d64d98a605e052ac57c19bab4984e3d2559267ad --- /dev/null +++ b/docs/contributing/ci/failures.md @@ -0,0 +1,4 @@ +# CI Failures + +What should I do when a CI job fails on my PR, but I don't think my PR caused +the failure? diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md new file mode 100644 index 0000000000000000000000000000000000000000..bf56914f8da844f9fb720e61bf02bd89be2e0a5d --- /dev/null +++ b/docs/contributing/ci/tests_markers.md @@ -0,0 +1,160 @@ +# Markers for Tests + +By adding markers before test functions, tests can later be executed uniformly by simply declaring the corresponding marker type. 
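+
+For example, once a test carries a marker, it can be selected from the command line with pytest's standard `-m` option (marker names as defined in the table below):
+
+```bash
+# Run only core model tests
+pytest -m core_model tests/
+
+# Run omni tests but skip slow ones
+pytest -m "omni and not slow" tests/
+```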
+ +## Current Markers +Defined in `pyproject.toml`: + +| Marker | Description | +| ------------------ | ------------------------------------------------------- | +| `core_model` | Core model tests (run in each PR) | +| `diffusion` | Diffusion model tests | +| `omni` | Omni model tests | +| `cache` | Cache backend tests | +| `parallel` | Parallelism/distributed tests | +| `cpu` | Tests that run on CPU | +| `gpu` | Tests that run on GPU (auto-added) | +| `cuda` | Tests that run on CUDA (auto-added) | +| `rocm` | Tests that run on AMD/ROCm (auto-added) | +| `npu` | Tests that run on NPU/Ascend (auto-added) | +| `H100` | Tests that require H100 GPU | +| `L4` | Tests that require L4 GPU | +| `MI325` | Tests that require MI325 GPU (AMD/ROCm) | +| `A2` | Tests that require A2 NPU | +| `A3` | Tests that require A3 NPU | +| `distributed_cuda` | Tests that require multi cards on CUDA platform | +| `distributed_rocm` | Tests that require multi cards on ROCm platform | +| `distributed_npu` | Tests that require multi cards on NPU platform | +| `skipif_cuda` | Skip if the num of CUDA cards is less than the required | +| `skipif_rocm` | Skip if the num of ROCm cards is less than the required | +| `skipif_npu` | Skip if the num of NPU cards is less than the required | +| `slow` | Slow tests (may skip in quick CI) | +| `benchmark` | Benchmark tests | + +For those markers shown as auto-added, they will be added by the `@hardware_test` decorator. + +### Example usage for markers + +```python +from tests.utils import hardware_test + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, +) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_video_to_audio() + ... +``` +### Decorator: `@hardware_test` + +This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions: + +1. **Applies platform and resource markers** + Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `A2`, `A3`). + ``` + @pytest.mark.cuda + @pytest.mark.L4 + ``` +2. **Handles multi-card (distributed) scenarios** + For tests requiring multiple cards, it automatically adds distributed markers such as `distributed_cuda`, `distributed_rocm`, or `distributed_npu`. + ``` + @pytest.mark.distributed_cuda(num_cards=num_cards) + ``` +3. **Supports flexible card requirements** + Accepts `num_cards` as either a single integer for all platforms or as a dictionary with per-platform values. If not specified, defaults to 1 card per platform. + +4. **Integrates resource validation** + On CUDA, adds a skip marker (`skipif_cuda`) if the system does not have the required number of devices. + Support for `skipif_rocm` and `skipif_npu` will be implemented later. + + +5. **Runs each test in a new process** + Automatically wraps the distributed test with a decorator (`@create_new_process_for_each_test`) to ensure isolation and compatibility with multi-process hardware backends. + +6. **Works with pytest filtering** + Allows tests to be filtered and selected at runtime using standard pytest marker expressions (e.g., `-m "distributed_cuda and L4"`). 
+ +#### Example usage for decorator +- Single call for multiple platforms: + ```python + @hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards={"cuda": 2, "rocm": 2, "npu": 2}, + ) + ``` + or + ```python + @hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, + ) + ``` +- `res` must be a dict; supported resources: CUDA (L4/H100), ROCm (MI325), NPU (A2/A3) +- `num_cards` can be int (all platforms) or dict (per platform); defaults to 1 when missing +- `hardware_test` automatically applies `@create_new_process_for_each_test` for distributed tests. +- Distributed markers (`distributed_cuda`, `distributed_rocm`, `distributed_npu`) are auto-added for multi-card cases +- Filtering examples: + - CUDA only: `pytest -m "distributed_cuda and L4"` + - ROCm only: `pytest -m "distributed_rocm and MI325"` + - NPU only: `pytest -m "distributed_npu"` + +## Add Support for a New Platform + +If you want to add support for a new platform (e.g., "tpu" for a new accelerator), follow these steps: + +1. **Extend the marker list in your pytest config** so that platform/resource markers are defined: + ```toml + # In pyproject.toml or pytest.ini + [tool.pytest.ini_options] + markers = [ + # ... existing markers ... + "tpu: Tests that require TPU device", + "TPU_V3: Tests that require TPU v3 hardware", + "distributed_tpu: Tests that require multiple TPU devices", + ] + ``` +2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`: + ```python + # In vllm-omni/tests/utils.py + + def tpu_marks(*, res: str, num_cards: int): + test_platform = pytest.mark.tpu + if res == "TPU_V3": + test_resource = pytest.mark.TPU_V3 + else: + raise ValueError( + f"Invalid TPU resource type: {res}. Supported: TPU_V3") + + if num_cards == 1: + return [test_platform, test_resource] + else: + test_distributed = pytest.mark.distributed_tpu(num_cards=num_cards) + # Optionally: add skipif_tpu when implemented + return [test_platform, test_resource, test_distributed] + ``` +3. **Update `hardware_test` to recognize your new platform**: + In the relevant place (see the `hardware_test` implementation), add: + ```python + if platform == "tpu": + marks = tpu_marks(res=resource, num_cards=cards) + ``` +4. **(Recommended) Add a test using your new markers**: + ```python + @hardware_test( + res={"tpu": "TPU_V3"}, + num_cards=2, + ) + def test_my_tpu_feature(): + ... + ``` + +**Summary**: +- Add pytest markers for your new platform/resources +- Implement a marker function (`xxx_marks`) +- Plug into `hardware_test` +- You're done: tests decorated with `@hardware_test` using your platform now automatically get the correct markers, distribution, and isolation! + +See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`). diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md new file mode 100644 index 0000000000000000000000000000000000000000..65c2b044346526157348892af90360ca70d05c3b --- /dev/null +++ b/docs/contributing/ci/tests_style.md @@ -0,0 +1,280 @@ +# Test File Structure and Style Guide + +To ensure project maintainability and sustainable development, we encourage contributors to submit test code (unit tests, system tests, or end-to-end tests) alongside their code changes. This document outlines the guidelines for organizing and naming test files. 
+ +## Test Types + +### Unit Tests and System Tests +For unit tests and system tests, we strongly recommend placing test files in the same directory structure as the source code being tested, using the naming convention `test_*.py`. + +### End-to-End (E2E) Tests for Models +End-to-end tests verify the complete functionality of a system or component. For our project, the E2E tests for different omni models are organized into two subdirectories: + +- **`tests/e2e/offline_inference/`**: Tests for offline inference modes (e.g., Qwen3Omni offline inference) + +- **`tests/e2e/online_serving/`**: Tests for online serving scenarios (e.g., API server tests) + +**Example:** The test file for `vllm_omni/entrypoints/omni_llm.py` should be located at `tests/entrypoints/test_omni_llm.py`. + +## Test Directory Structure + +The ideal directory structure mirrors the source code organization: + +``` +vllm_omni/ tests/ +├── config/ → ├── config/ +│ └── model.py │ └── test_model.py +│ +├── core/ → ├── core/ +│ └── sched/ │ └── sched/ # Maps to core/sched/ +│ ├── omni_ar_scheduler.py │ ├── test_omni_ar_scheduler.py +│ ├── omni_generation_scheduler.py │ ├── test_omni_generation_scheduler.py +│ └── output.py │ └── test_output.py +│ +├── diffusion/ → ├── diffusion/ +│ ├── diffusion_engine.py │ ├── test_diffusion_engine.py +│ ├── omni_diffusion.py │ ├── test_omni_diffusion.py +│ ├── attention/ │ ├── attention/ # Maps to diffusion/attention/ +│ │ └── backends/ │ │ └── test_*.py +│ ├── models/ │ ├── models/ # Maps to diffusion/models/ +│ │ ├── qwen_image/ │ │ ├── qwen_image/ +│ │ │ └── ... │ │ │ └── test_*.py +│ │ └── z_image/ │ │ └── z_image/ +│ │ └── ... │ │ └── test_*.py +│ └── worker/ │ └── worker/ # Maps to diffusion/worker/ +│ └── ... │ └── test_*.py +│ +├── distributed/ → ├── distributed/ +│ └── ... │ └── test_*.py +│ +├── engine/ → ├── engine/ +│ ├── processor.py │ ├── test_processor.py +│ └── output_processor.py │ └── test_output_processor.py +│ +├── entrypoints/ → ├── entrypoints/ +│ ├── omni_llm.py │ ├── test_omni_llm.py # UT: OmniLLM core logic (mocked) +│ ├── omni_stage.py │ ├── test_omni_stage.py # UT: OmniStage logic +│ ├── omni.py │ ├── test_omni.py # E2E: Omni class (offline inference) +│ ├── async_omni.py │ ├── test_async_omni.py # E2E: AsyncOmni class +│ ├── cli/ │ ├── cli/ # Maps to entrypoints/cli/ +│ │ └── ... │ │ └── test_*.py +│ └── openai/ │ └── openai/ # Maps to entrypoints/openai/ +│ ├── api_server.py │ ├── test_api_server.py # E2E: API server (online serving) +│ └── serving_chat.py │ └── test_serving_chat.py +│ +├── inputs/ → ├── inputs/ +│ ├── data.py │ ├── test_data.py +│ ├── parse.py │ ├── test_parse.py +│ └── preprocess.py │ └── test_preprocess.py +│ +├── model_executor/ → ├── model_executor/ +│ ├── layers/ │ ├── layers/ +│ │ └── mrope.py │ │ └── test_mrope.py +│ ├── model_loader/ │ ├── model_loader/ +│ │ └── weight_utils.py │ │ └── test_weight_utils.py +│ ├── models/ │ ├── models/ +│ │ ├── qwen2_5_omni/ │ │ ├── qwen2_5_omni/ +│ │ │ ├── qwen2_5_omni_thinker.py │ │ │ ├── test_qwen2_5_omni_thinker.py # UT +│ │ │ ├── qwen2_5_omni_talker.py │ │ │ ├── test_qwen2_5_omni_talker.py # UT +│ │ │ └── qwen2_5_omni_token2wav.py │ │ │ └── test_qwen2_5_omni_token2wav.py # UT +│ │ └── qwen3_omni/ │ │ └── qwen3_omni/ +│ │ └── ... │ │ └── test_*.py +│ ├── stage_configs/ │ └── stage_configs/ # Configuration tests (if needed) +│ │ └── ... │ └── test_*.py +│ └── stage_input_processors/ │ └── stage_input_processors/ +│ └── ... │ └── test_*.py +│ +├── sample/ → ├── sample/ +│ └── ... 
│ └── test_*.py +│ +├── utils/ → ├── utils/ +│ └── platform_utils.py │ └── test_platform_utils.py +│ +├── worker/ → ├── worker/ + ├── gpu_ar_worker.py │ ├── test_gpu_ar_worker.py + ├── gpu_generation_worker.py │ ├── test_gpu_generation_worker.py + ├── gpu_model_runner.py │ ├── test_gpu_model_runner.py + └── npu/ │ └── npu/ # Maps to worker/npu/ + └── ... │ └── test_*.py +│ +└── e2e/ → ├── e2e/ # End-to-end scenarios (no 1:1 source mirror) + ├── online_serving/ # Full-stack online serving flows + │ └── (empty for now) + └── offline_inference/ # Full offline inference flows + ├── test_qwen2_5_omni.py # Moved from multi_stages/ + ├── test_qwen3_omni.py # Moved from multi_stages_h100/ + ├── test_t2i_model.py # Moved from single_stage/ + └── stage_configs/ # Shared stage configs + ├── qwen2_5_omni_ci.yaml + └── qwen3_omni_ci.yaml +``` + + + +### Naming Conventions + +- **Unit Tests**: Use `test_.py` format. Example: `omni_llm.py` → `test_omni_llm.py` + +- **E2E Tests**: Place in `tests/e2e/offline_inference/` or `tests/e2e/online_serving/` with descriptive names. Example: `tests/e2e/offline_inference/test_qwen3_omni.py`, `tests/e2e/offline_inference/test_diffusion_model.py` + +### Best Practices + +1. **Mirror Source Structure**: Test directories should mirror the source code structure +2. **Test Type Indicators**: Use comments to indicate test types (UT for unit tests, E2E for end-to-end tests) +3. **Shared Resources**: Place shared test configurations (e.g., CI configs) in appropriate subdirectories +4. **Consistent Naming**: Follow the `test_*.py` naming convention consistently across all test files + + +## Test codes requirements + +### Coding style + +1. **File header**: Add SPDX license header to all test files +2. **Imports**: Pls don't use manual `sys.path` modifications, use standard imports instead. +3. **Test type differentiation**: + + - Unit tests: Maintain mock style + - E2E tests for models: Consider using OmniRunner uniformly, avoid decorators + +4. **Documentation**: Add docstrings to all test functions +5. **Environment variables**: Set uniformly in `conftest.py` or at the top of files +6. **Type annotations**: Add type annotations to all test function parameters +7. **Pytest Markers**: Add necessary markers like `@pytest.mark.core_model` and use `@hardware_test` to declare hardware requirements (check detailed in [Markers for Tests](../ci/tests_markers.md)). + +### Template +#### E2E - Online serving + +```python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Online E2E smoke test for an omni model (video,text,audio → audio). 
+""" +from pathlib import Path + +import pytest +import openai + +from tests.utils import hardware_test + +# Optional: set process start method for workers +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +models = ["{your model name}"] #Edit here to load your model +stage_configs = [str(Path(__file__).parent / "stage_configs" / {your model yaml})] #Edit here to load your model yaml +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + +#OmniServer,Used to start the vllm-omni server +class OmniServer: + xxx + + +@pytest.fixture +def omni_server(request): + model, stage_config_path = request.param + with OmniServer(model, ["--stage-configs-path", stage_config_path]) as server: + yield server + + +#handle request message +@pytest.fixture(scope="session") +def base64_encoded_video() -> str: + xxx + +@pytest.fixture(scope="session") +def dummy_messages_from_video_data(video_data_url: str, content_text: str) -> str: + xxx + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards={"cuda": 2, "rocm": 2, "npu": 4}, +) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_video_to_audio( + client: openai.OpenAI, + omni_server, + base64_encoded_video: str, +) -> None: + #set message + video_data_url = f"data:video/mp4;base64, {base64_encoded_video}" + messages = dummy_messages_from_video_data(video_data_url) + + #send request + chat_completion = client.chat.completions.create( + model=omni_server.model, + messages=messages, + ) + + #verify text output + text_choice = chat_completion.choices[0] + assert text_choice.finish_reason == "length" + + #verify audio output + audio_choice = chat_completion.choices[1] + audio_message = audio_choice.message + if hasattr(audio_message, "audio") and audio_message.audio: + assert audio_message.audio.data is not None + assert len(audio_message.audio.data) > 0 +``` + +#### E2E - Offline inference +```python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Offline E2E smoke test for an omni model (video → audio). 
+""" + +import os +from pathlib import Path + +import pytest +from vllm.assets.video import VideoAsset + +from tests.utils import hardware_test +from ..multi_stages.conftest import OmniRunner + +# Optional: set process start method for workers +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +models = ["{your model name}"] #Edit here to load your model +stage_configs = [str(Path(__file__).parent / "stage_configs" / {your model yaml})] #Edit here to load your model yaml + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + +# function name: test_{input_modality}_to_{output_modality} +# modality candidate: text, image, audio, video, mixed_modalities +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, +) +@pytest.mark.parametrize("test_config", test_params) +def test_video_to_audio(omni_runner: type[OmniRunner], model: str) -> None: + """Offline inference: video input, audio output.""" + model, stage_config_path = test_config + with omni_runner(model, seed=42, stage_configs_path=stage_config_path) as runner: + # Prepare inputs + video = VideoAsset(name="sample", num_frames=4).np_ndarrays + + outputs = runner.generate_multimodal( + prompts="Describe this video briefly.", + videos=video, + ) + + # Minimal assertions: got outputs and at least one audio result + assert outputs + has_audio = any(o.final_output_type == "audio" for o in outputs) + assert has_audio +``` + +## Checklist before submitting your test files + +1. The file is saved in an appropriate place and the file name is clear. +2. The coding style follows the requirements outlined above. +3. **All test functions have appropriate pytest markers** +4. For tests that need run in CI, please ensure the test is configured under the `./buildkite/` folder. diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b3e951c8bfe0071345e0618bb01bfa05d6217473 --- /dev/null +++ b/docs/contributing/model/README.md @@ -0,0 +1,15 @@ +# Adding a New Model + +This section provides comprehensive guidance on how to add a new model to vLLM-Omni. + +## Documentation + +- **[Adding an Omni-Modality Model](adding_omni_model.md)**: Complete step-by-step guide using Qwen3-Omni as an example. + +- **[Adding a Diffusion Model](adding_diffusion_model.md)**: Complete step-by-step guide using Qwen/Qwen-Image-Edit as an example. + + + +## Quick Start + +For a quick reference, see the [Adding a New Multi-Stage Model Guide](adding_omni_model.md) and [Adding a New Diffusion Model Guide](adding_diffusion_model.md). diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md new file mode 100644 index 0000000000000000000000000000000000000000..80212360fa36d350daf93e25b4ce9cde60bcffbe --- /dev/null +++ b/docs/contributing/model/adding_diffusion_model.md @@ -0,0 +1,371 @@ +# Adding a Diffusion Model +This guide walks through the process of adding a new Diffusion model to vLLM-Omni, using Qwen/Qwen-Image-Edit as a comprehensive example. + +# Table of Contents +1. [Overview](#overview) +2. [Directory Structure](#directory-structure) +3. [Step-by-Step Implementation](#step-by-step-implementation) +4. [Testing](#testing) +5. 
[Adding a Model Recipe](#adding-a-model-recipe) + + +# Overview +When add a new diffusion model into vLLM-Omni, additional adaptation work is required due to the following reasons: + ++ New model must follow the framework’s parameter passing mechanisms and inference flow. + ++ Replacing the model’s default implementations with optimized modules, which is necessary to achieve the better performance. + +The diffusion execution flow as follow: +

+*Figure: Diffusion Flow*

+ + +# Directory Structure +File Structure for Adding a New Diffusion Model + +``` +vllm_omni/ +└── examples/ + └──offline_inference + └── example script # reuse existing if possible (e.g., image_edit.py) + └──online_serving + └── example script +└── diffusion/ + └── registry.py # Registry work + ├── request.py # Request Info + └── models/your_model_name/ # Model directory (e.g., qwen_image) + └── pipeline_xxx.py # Model implementation (e.g., pipeline_qwen_image_edit.py) +``` + +# Step-by-step-implementation +## Step 1: Model Implementation +The diffusion pipeline’s implementation follows **HuggingFace Diffusers**. +### 1.1 Define the Pipeline Class +Define the pipeline class, e.g., `QwenImageEditPipeline`, and initialize all required submodules, either from HuggingFace `diffusers` or custom implementations. In `QwenImageEditPipeline`, only `QwenImageTransformer2DModel` is re-implemented to support optimizations such as Ulysses-SP. When adding new models in the future, you can either reuse this re-implemented `QwenImageTransformer2DModel` or extend it as needed. + +### 1.2 Pre-Processing and Post-Processing Extraction +Extract the pre-processing and post-processing logic from the pipeline class to follow vLLM-Omni’s execution flow. For Qwen-Image-Edit: +```python +def get_qwen_image_edit_pre_process_func( + od_config: OmniDiffusionConfig, +): + """ + Define a pre-processing function that resizes input images and + pre-process for subsequent inference. + """ +``` + +```python +def get_qwen_image_edit_post_process_func( + od_config: OmniDiffusionConfig, +): + """ + Defines a post-processing function that post-process images. + """ +``` + +### 1.3 Define the forward function +The forward function of `QwenImageEditPipeline` follows the HuggingFace `diffusers` design for the most part. The key differences are: ++ As described in the overview, arguments are passed through `OnniDiffusionRequest`, so we need to get user parameters from it accordingly. +```python +prompt = req.prompt +``` ++ pre/post-processing are handled by the framework elsewhere, so skip them. + +### 1.4 Replace some ops or layers in DiT component + +vLLM-Omni provides a set of optimized operators with better performance and built-in support for parallelism, including attention, rotary embeddings (RoPE), and linear layers. + +Below is an example showing how to replace standard Transformer attention and FFN layers with vLLM-Omni implementations: + +```python +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +class MyAttention(nn.Module): + def __init__(self): + super().__init__() + self.attn = Attention() + self.to_qkv = QKVParallelLinear() + self.to_out = RowParallelLinear() + self.rope = RotaryEmbedding(is_neox_style=False) + + def forward(self, hidden_states): + qkv, _ = self.to_qkv(hidden_states) + q, k, v = qkv.split(...) + q, k = self.rope(...) + attn_output = self.attn(q, k, v) + output = self.to_out(attn_output) + +class MyFFN(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = ColumnParallelLinear() + self.fc2 = RowParallelLinear() + self.act = F.gelu + + def forward(self, hidden_states): + hidden, _ = self.fc1(hidden_states) + hidden = self.act(hidden) + output = self.fc2(hidden) + return output +``` + +In this example: + ++ Attention uses vLLM-Omni’s optimized attention kernel together with parallel QKV projection and RoPE. + ++ Linear layers are replaced with column- and row-parallel variants to enable tensor parallelism. 
+ ++ The FFN follows a standard two-layer structure and can be further optimized (e.g., using fused or merged projections) if needed. + + +### 1.5 Provide a `_repeated_blocks` in DiT model +`_repeated_blocks` is the small and frequently-repeated block(s) of a model -- typically a transformer layer. + +It's used for torch compile optimizations. +```python +_repeated_blocks = ["QwenImageTransformerBlock"] +``` + + +### 1.6 (Optional) implement sequence parallelism +vLLM-Omni has a non-intrusive `_sp_plan` that enable sequence parallel without modifying `forward()` logic. +You can refer to [How to parallelize a new model](../../user_guide/diffusion/parallelism_acceleration.md) + + +### 1.7 (Optional) integrate with Cache-Dit +vLLM-Omni supports acceleration via [Cache-Dit](../../user_guide/diffusion/cache_dit_acceleration.md). Most models compatible with Diffusers can use Cache-Dit seamlessly. For new models, you can extend support by modifying`cache_dit_backend.py` + +## Step 2: Extend OmniDiffusionRequest Fields +User-provided inputs are ultimately passed to the model’s forward method through OmniDiffusionRequest, so we add the required fields here to support the new model. +```python +prompt: str | list[str] | None = None +negative_prompt: str | list[str] | None = None +... +``` + +## Step 3: Registry ++ registry diffusion model in registry.py +```python +_DIFFUSION_MODELS = { + # arch:(mod_folder, mod_relname, cls_name) + ... + "QwenImageEditPipeline": ( + "qwen_image", + "pipeline_qwen_image_edit", + "QwenImageEditPipeline", + ), + ... +} +``` ++ registry pre-process get function +```python +_DIFFUSION_PRE_PROCESS_FUNCS = { + # arch: pre_process_func + ... + "QwenImageEditPipeline": "get_qwen_image_edit_pre_process_func", + ... +} +``` + ++ registry post-process get function +```python +_DIFFUSION_POST_PROCESS_FUNCS = { + # arch: post_process_func + ... + "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func", + ... +} +``` + +## Step 4: Add an Example Script +For each newly integrated model, we need to provide examples script under the examples/ to demonstrate how to initialize the pipeline with Omni, pass in user inputs, and generate outputs. +Key point for writing the example: + ++ Use the Omni entrypoint to load the model and construct the pipeline. + ++ Show how to format user inputs and pass them via omni.generate(...). + ++ Demonstrate the common runtime arguments, such as: + + + model path or model name + + + input image(s) or prompt text + + + key diffusion parameters (e.g., inference steps, guidance scale) + + + optional acceleration backends (e.g., Cache-DiT, TeaCache) + ++ Save or display the generated results so users can validate the integration. + +## Step 5: TeaCache Coefficient Estimation (Optional) + +If your model supports TeaCache acceleration, you need to estimate the polynomial coefficients for optimal caching performance. + +### 5.1 Add Extractor Function + +First, implement an extractor function in `vllm_omni/diffusion/cache/teacache/extractors.py`. The extractor extracts the modulated input and defines how to run transformer blocks: + +```python +def extract_your_model_context( + module: nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + **kwargs: Any, +) -> CacheContext: + # 1. Preprocessing + temb = module.time_embed(timestep) + + # 2. Extract modulated input (for cache decision) + modulated_input = module.transformer_blocks[0].norm1(hidden_states, temb) + + # 3. 
Define transformer execution + def run_transformer_blocks(): + h = hidden_states + for block in module.transformer_blocks: + h = block(h, temb=temb) + return (h,) + + # 4. Define postprocessing + def postprocess(h): + return module.proj_out(module.norm_out(h, temb)) + + return CacheContext( + modulated_input=modulated_input, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) +``` + +Register it in `EXTRACTOR_REGISTRY`: +```python +EXTRACTOR_REGISTRY = { + ... + "YourTransformer2DModel": extract_your_model_context, +} +``` + +### 5.2 Add Adapter for Coefficient Estimation + +Add an adapter in `vllm_omni/diffusion/cache/teacache/coefficient_estimator.py`: + +```python +class YourModelAdapter: + @staticmethod + def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: + # Load your pipeline + ... + + @staticmethod + def get_transformer(pipeline: Any) -> tuple[Any, str]: + return pipeline.transformer, "YourTransformer2DModel" + + @staticmethod + def install_hook(transformer: Any, hook: DataCollectionHook) -> None: + registry = HookRegistry.get_or_create(transformer) + registry.register_hook(hook._HOOK_NAME, hook) + +_MODEL_ADAPTERS["YourModel"] = YourModelAdapter +``` + +### 5.3 Run Coefficient Estimation + +Use the provided script to estimate coefficients: + +```python +from vllm_omni.diffusion.cache.teacache.coefficient_estimator import ( + TeaCacheCoefficientEstimator, +) +from datasets import load_dataset +from tqdm import tqdm + +# Load model +estimator = TeaCacheCoefficientEstimator( + model_path="/path/to/model", + model_type="Bagel", # Your model type + device="cuda", +) + +# Load prompts (paper suggests ~70 prompts) +dataset = load_dataset("nateraw/parti-prompts", split="train") +prompts = dataset["Prompt"][:70] + +# Collect data +for prompt in tqdm(prompts): + estimator.collect_from_prompt(prompt, num_inference_steps=50) + +# Estimate coefficients +coeffs = estimator.estimate(poly_order=4) +print(f"Coefficients: {coeffs}") +``` + +### 5.4 Interpreting Coefficient Estimation Results + +The estimator outputs statistics and polynomial coefficients. Here's how to interpret them: + +**Example Output:** +``` +Data statistics: +Count: 48 +Input Diffs (x): min=1.1089e-02, max=5.2555e-02, mean=2.8435e-02 +Output Diffs (y): min=2.8242e-02, max=2.9792e-01, mean=7.0312e-02 +Coefficients: [1333131.29, -168644.23, 7950.51, -163.75, 1.26] +``` + +**What to Check:** +- **Count**: Number of timestep pairs analyzed. Should be at least 30-50 for reliable estimation. Low count suggests insufficient prompts or inference steps. +- **Input/Output Ranges**: Verify output differences correlate with input differences. If ranges seem unusual, check your prompt diversity. +- **Coefficient Magnitude**: Extremely large values (>1e8) may indicate numerical instability - try collecting more diverse data. + +**Troubleshooting:** +- If results seem unreliable, try: + - Increasing number of prompts (100+ recommended) + - Using more diverse prompts from multiple datasets + - Adjusting `num_inference_steps` (try 20, 50, 100) + +### 5.5 Add Coefficients to Config + +Add the estimated coefficients to `vllm_omni/diffusion/cache/teacache/config.py`: + +```python +_MODEL_COEFFICIENTS = { + ... 
+ "YourTransformer2DModel": [ + 1.04730573e+06, # a4 + -1.34150749e+05, # a3 + 6.51517806e+03, # a2 + -1.41209108e+02, # a1 + 1.17241808e+00, # a0 + ], +} +``` +## Step 6: Open a Pull Request + +When submitting a pull request to add support for a new model, please include the following information in the PR description: + ++ Output verification: provide generation outputs to verify correctness and model behavior. + ++ Inference speed: provide a comparison with the corresponding implementation in Diffusers. + ++ Parallelism support: specify the supported parallel sizes and any relevant limitations. + ++ Cache acceleration: check whether the model can be accelerated using Cache-Dit or not. + + +Providing these details helps reviewers evaluate correctness, performance improvements, and parallel scalability of the new model integration. + +# Testing +For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md). + + +## Adding a Model Recipe +After implementing and testing your model, please add a model recipe to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository. This helps other users understand how to use your model with vLLM-Omni. diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md new file mode 100644 index 0000000000000000000000000000000000000000..81499118623b78e363763c59d5a78b8a36d335d6 --- /dev/null +++ b/docs/contributing/model/adding_omni_model.md @@ -0,0 +1,633 @@ +# Adding an Omni-Modality Model + +This guide walks through the process of adding a new multi-stage model to vLLM-Omni, using **Qwen3-Omni** as a comprehensive example. Qwen3-Omni is a multi-stage omni-modality model that demonstrates the full capabilities of vLLM-Omni's architecture. + +## Table of Contents + +1. [Overview](#overview) +2. [Directory Structure](#directory-structure) +3. [Step-by-Step Implementation](#step-by-step-implementation) +4. [Key Components](#key-components) +5. [Model Registration](#model-registration) +6. [Stage Configuration](#stage-configuration) +7. [Stage Input Processors](#stage-input-processors) +8. [Testing](#testing) +9. [Adding a Model Recipe](#adding-a-model-recipe) +10. [Summary](#summary) + +## Overview + +vLLM-Omni supports multi-stage model architectures where different stages can run on different devices and process different modalities. The Qwen3-Omni model exemplifies this with three stages: + +1. **Thinker Stage**: Multimodal understanding (text + audio + video) → text generation +2. **Talker Stage**: Text embeddings → RVQ codec codes +3. **Code2Wav Stage**: RVQ codes → audio waveform + +Each stage is implemented as a separate model class that can be configured independently. + +## Directory Structure + +When adding a new model, you'll need to create the following structure: + +``` +vllm_omni/model_executor/models/ +└── your_model_name/ # Model directory (e.g., qwen3_omni) + ├── __init__.py # Exports main model class + ├── your_model.py # Main unified model class + ├── your_model_stage1_implementation.py # Stage 1 implementation (e.g., thinker) + ├── your_model_stage2_implementation.py # Stage 2 implementation (e.g., talker) + └── your_model_stage3_implementation.py # Stage 3 implementation (e.g., code2wav) + └── ... 
maybe other stage implementations + +vllm_omni/model_executor/stage_input_processors/ +└── your_model_name.py # Stage transition processors + +vllm_omni/model_executor/stage_configs/ +└── your_model_name.yaml # Stage configuration file +``` + +## Step-by-Step Implementation + +### Step 1: Create the Model Directory + +Create a new directory under `vllm_omni/model_executor/models/` + +### Step 2: Implement Stage Components + +For Qwen3-Omni, we have three stage components: + +#### 2.1 Thinker Stage (`qwen3_omni_moe_thinker.py`) + +The thinker stage handles multimodal understanding. Key features: + +- Inherits from base Qwen3 MoE model in vLLM, using vLLM fused ops & page attn to accelerate +- Implements multimodal processing interfaces +- Handles audio, video, and image inputs +- Generates text outputs + +```python +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM + +class Qwen3OmniMoeThinkerForConditionalGeneration( + Qwen3MoeForCausalLM, + SupportsMultiModal, + SupportsPP +): + """Thinker stage: multimodal understanding → text generation.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Initialize base model + # Set up multimodal processors + # Configure audio/video/image encoders + pass +``` + +#### 2.2 Talker Stage (`qwen3_omni_moe_talker.py`) + +The talker stage converts text embeddings to codec codes: + +```python +class Qwen3OmniMoeTalkerForConditionalGeneration( + Qwen3MoeForCausalLM, + SupportsPP +): + """Talker stage: text embeddings → RVQ codec codes.""" + + def __init__(self, vllm_config, talker_config, prefix): + # Initialize base model + # Replace LM head with codec head + # Set up text projection from thinker + pass +``` + +#### 2.3 Code2Wav Stage (`qwen3_omni_code2wav.py`) + +The code2wav stage generates audio waveforms: + +```python +class Qwen3OmniMoeCode2Wav(nn.Module): + """Code2Wav stage: RVQ codes → audio waveform.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Initialize audio decoder + # Set up codec processing + pass +``` + +### Step 3: Implement the Unified Model Class + +The main model class (`qwen3_omni.py`) orchestrates all stages: + +```python +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, Qwen3OmniMoeConditionalGenerationMixin +): + """ + Unified Qwen3 Omni MoE model combining thinker, talker, and code2wav. 
+ + Architecture: + - Thinker: Multimodal understanding (text + audio + video) → text generation + - Talker: Text embeddings → RVQ codec codes + - Code2Wav: RVQ codes → audio waveform + + Usage: + Set `model_stage` in vllm_config to one of: "thinker", "talker", "code2wav" + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + config: Qwen3OmniMoeConfig = vllm_config.model_config.hf_config + + # Determine which stage to initialize + self.model_stage = vllm_config.model_config.model_stage + + if self.model_stage == "thinker": + # Initialize thinker model + thinker_vllm_config = vllm_config.with_hf_config( + config.thinker_config, + architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"] + ) + self.thinker = init_vllm_registered_model( + vllm_config=thinker_vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + hf_config=config.thinker_config, + architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"], + ) + self.model = self.thinker + + elif self.model_stage == "talker": + # Initialize talker model + talker_vllm_config = vllm_config.with_hf_config( + config.talker_config, + architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"] + ) + self.talker = init_vllm_registered_model( + vllm_config=talker_vllm_config, + prefix=maybe_prefix(prefix, "talker"), + hf_config=config.talker_config, + architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"], + ) + self.model = self.talker + + elif self.model_stage == "code2wav": + # Initialize code2wav model + code2wav_vllm_config = vllm_config.with_hf_config( + config.code2wav_config, + architectures=["Qwen3OmniMoeCode2Wav"] + ) + self.code2wav = init_vllm_registered_model( + vllm_config=code2wav_vllm_config, + prefix=maybe_prefix(prefix, "code2wav"), + hf_config=config.code2wav_config, + architectures=["Qwen3OmniMoeCode2Wav"], + ) + self.model = self.code2wav + else: + raise ValueError( + f"Invalid model_stage: {self.model_stage}. " + f"Must be one of: 'thinker', 'talker', 'code2wav'" + ) +``` + +#### Key Methods to Implement + +1. **`forward()`**: Handles the forward pass for each stage +2. **`embed_input_ids()`**: Embeds input token IDs +3. **`embed_multimodal()`**: Processes multimodal inputs (if applicable) +4. **`compute_logits()`**: Computes logits from hidden states +5. **`load_weights()`**: Loads model weights with proper prefixing of different stages + +### Step 4: Create `__init__.py` + +Export the main model class: + +```python +# vllm_omni/model_executor/models/qwen3_omni/__init__.py +from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration + +__all__ = ["Qwen3OmniMoeForConditionalGeneration"] +``` + +## Key Components + +### 1. Model Interfaces + +Your model should implement the appropriate interfaces: + +- **`SupportsMultiModal`**: For models that process multimodal inputs +- **`SupportsPP`**: For models that support pipeline parallelism +- **`SupportsMRoPE`**: For models using multi-dimensional RoPE (if applicable) + +### 2. Multimodal Registration + +If your model processes multimodal inputs, register it with the multimodal registry: + +```python +@MULTIMODAL_REGISTRY.register_processor( + YourMultiModalProcessor, + info=YourProcessingInfo, + dummy_inputs=YourDummyInputsBuilder, +) +class YourModel(nn.Module, SupportsMultiModal): + pass +``` + +### 3. 
Weight Loading + +Implement `load_weights()` to handle weight loading with proper prefixing: + +```python +def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for all components of the omni model.""" + loaded_weights = set() + thinker_weights = [] + talker_weights = [] + code2wav_weights = [] + + # Separate weights by component + for k, v in weights: + if k.startswith("thinker."): + thinker_weights.append((k, v)) + elif k.startswith("talker."): + talker_weights.append((k, v)) + elif k.startswith("code2wav."): + code2wav_weights.append((k, v)) + + # Load each component's weights + if self.thinker and thinker_weights: + thinker_loaded = self.thinker.load_weights(thinker_weights) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") + loaded_weights.update(thinker_loaded) + + # Similar for talker and code2wav... + + return loaded_weights +``` + +### 4. Output Format + +Use `OmniOutput` for stage outputs: + +```python +from vllm_omni.model_executor.models.output_templates import OmniOutput + +# In forward method +return OmniOutput( + text_hidden_states=hidden_states, + multimodal_outputs={"additional_data": data}, + next_token_id=next_token_id, +) +``` + +## Model Registration + +Register your model in `vllm_omni/model_executor/models/registry.py`: + +```python +_OMNI_MODELS = { + # ... existing models ... + + # Your new model + "YourModelForConditionalGeneration": ( + "your_model_name", # Module folder name + "your_model", # Module file name (without .py) + "YourModelForConditionalGeneration", # Class name + ), + "YourModelThinkerForConditionalGeneration": ( + "your_model_name", + "your_model_thinker", + "YourModelThinkerForConditionalGeneration", + ), + # ... other stages ... +} +``` + +The registry uses lazy loading, so the model class is imported only when needed. + +## Stage Configuration + +Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml). + +### Key Configuration Fields + +- **`model_stage`**: Which stage to run ("thinker", "talker", "code2wav", etc.) +- **`model_arch`**: The model architecture name (must match registry) +- **`engine_input_source`**: List of stage IDs that provide input to this stage +- **`custom_process_input_func`**: Function to process inputs from previous stages +- **`final_output`**: Whether this stage produces the final output (True/False) +- **`final_output_type`**: Type of final output ("text", "audio", "image", etc.) + +## Stage Input Processors + +Stage transitions are the mechanism by which outputs from one stage are converted into inputs for the next stage. This section explains where and how stage transitions occur. + +### Where Stage Transitions Are Called + +Stage transitions happen automatically in the orchestrator (`OmniLLM` class) during the generation loop. Here's the detailed flow: + +1. **Location**: `vllm_omni/entrypoints/omni_llm.py` in the `_run_generation()` method +2. **Trigger**: When a stage completes processing and produces outputs +3. 
**Execution Flow**: + ```python + # In omni_llm.py, _run_generation() method (around line 345-460) + + # Main orchestrator loop polls each stage for completed requests + for stage_id, stage in enumerate(self.stage_list): + result = stage.try_collect() # Get completed request + if result is None: + continue + + # Store outputs from this stage + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + stage.set_engine_outputs(engine_outputs) + + # Check if there's a next stage to forward to + next_stage_id = stage_id + 1 + if next_stage_id < len(self.stage_list): + next_stage: OmniStage = self.stage_list[next_stage_id] + + # THIS IS WHERE STAGE TRANSITION HAPPENS + next_inputs = next_stage.process_engine_inputs( + self.stage_list, + [request_id_to_prompt[req_id]] + ) + + # Submit to next stage + task = { + "type": OmniStageTaskType.GENERATE, + "request_id": req_id, + "engine_inputs": next_inputs[0], + "sampling_params": sampling_params_list[next_stage_id], + } + next_stage.submit(task) + ``` + +### How Stage Transitions Work + +The stage transition process follows these steps: + +1. **Stage Completion**: When a stage finishes processing a request, it stores outputs via `stage.set_engine_outputs(engine_outputs)` + +2. **Transition Detection**: The orchestrator checks if there's a next stage and calls `process_engine_inputs()` on it + +3. **Input Processing**: The `process_engine_inputs()` method in `OmniStage` (`omni_stage.py`) handles the transition: + ```python + def process_engine_inputs( + self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None + ) -> list[OmniTokensPrompt | TextPrompt]: + """Process engine inputs for this stage from upstream stage outputs.""" + + if self.custom_process_input_func is None: + # Default behavior: pass token IDs directly + # Extract outputs from source stage + source_stage_id = self.engine_input_source[0] + source_outputs = stage_list[source_stage_id].engine_outputs + # ... create OmniTokensPrompt from token_ids ... + else: + # Custom transition function (YOUR CODE HERE) + return self.custom_process_input_func( + stage_list, + self.engine_input_source, + prompt, + self.requires_multimodal_data + ) + ``` + - If `custom_process_input_func` is configured, it calls that function + - Otherwise, it uses default behavior (passing token IDs directly) + +4. **Custom Function Execution**: Your custom function receives: + - `stage_list`: List of all stage objects (to access upstream stage outputs) + - `engine_input_source`: List of source stage IDs (e.g., `[0]` for stage 0) + - `prompt`: Original prompt data (for preserving multimodal data) + - `requires_multimodal_data`: Whether multimodal data is required + +5. **Output Format**: The function must return a list of `OmniTokensPrompt` objects ready for the next stage + +### Data Structures in Stage Transitions + +Understanding the data structures is crucial for implementing stage transitions: + +**Input to your function:** +- `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects + - Each contains `outputs`: List of `RequestOutput` objects + - Each `RequestOutput` has: + - `token_ids`: Generated token IDs + - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc. 
+ - These are the hidden states or intermediate outputs from the model's forward pass + - `prompt_token_ids`: Original prompt token IDs + +**Output from your function:** +- Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains: + - `prompt_token_ids`: List[int] - Token IDs for the next stage + - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states) + - `multi_modal_data`: Optional multimodal data if needed + +### How Model Outputs Are Stored + +The model's `forward()` method returns an `OmniOutput` object that contains: +- `text_hidden_states`: Final hidden states for text generation +- `multimodal_outputs`: Dict containing intermediate outputs + +These outputs are captured during the forward pass and stored in `multimodal_output` with specific keys: + +```python +# In your model's forward() method (e.g., qwen3_omni.py) +def forward(self, ...): + # ... processing ... + + # For thinker stage: capture embeddings and hidden states + multimodal_outputs = { + "0": captured_embeddings, # Layer 0 embeddings + "24": captured_hidden_states, # Layer 24 hidden states + "tts_bos_embed": tts_bos_embed, + "tts_eos_embed": tts_eos_embed, + # ... other intermediate outputs ... + } + + return OmniOutput( + text_hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + ) +``` + +These keys are then accessible in your stage transition function: +```python +# In stage_input_processors/qwen3_omni.py +thinker_embeddings = output.multimodal_output["0"] # Access by key +thinker_hidden_states = output.multimodal_output["24"] +``` + +### Key Points + +1. **Accessing Upstream Outputs**: Use `stage_list[source_stage_id].engine_outputs` to get outputs from the source stage +2. **Extracting Data**: Access `output.multimodal_output[key]` to get specific hidden states or intermediate results + - Keys are defined by your model's `forward()` method when it creates `multimodal_outputs` +3. **Device Management**: Move tensors to appropriate devices (CPU for serialization, GPU for processing) +4. **Shape Transformations**: Reshape tensors as needed for the next stage (e.g., flattening codec codes) +5. **Batch Handling**: Process each request in the batch separately and return a list + +### Complete Flow Diagram + +

+*Figure: Data Flow between stages*

+ +### Implementation Example + +Create stage transition processors in `vllm_omni/model_executor/stage_input_processors/your_model_name.py`: + +```python +# qwen3_omni.py + +def thinker2talker( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """ + Process thinker outputs to create talker inputs. + + Args: + stage_list: List of stage objects + engine_input_source: Source stage IDs (typically [0] for thinker) + prompt: Original prompt data + + Returns: + List of OmniTokensPrompt for talker stage + """ + source_stage_id = engine_input_source[0] + thinker_outputs = stage_list[source_stage_id].engine_outputs + talker_inputs = [] + + for thinker_output in thinker_outputs: + output = thinker_output.outputs[0] + # Extract thinker embeddings and hidden states + thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda() + thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda() + + info = { + "thinker_embeddings": thinker_embeddings, + "thinker_hidden_states": thinker_hidden_states, + "thinker_sequences": thinker_output.prompt_token_ids + output.token_ids, + "thinker_input_ids": thinker_output.prompt_token_ids, + } + + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * computed_length, + additional_information=info, + multi_modal_data=None, + ) + ) + + return talker_inputs + + +def talker2code2wav( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """ + Process talker outputs to create code2wav inputs. + """ + source_stage_id = engine_input_source[0] + talker_outputs = stage_list[source_stage_id].engine_outputs + code2wav_inputs = [] + + for talker_output in talker_outputs: + output = talker_output.outputs[0] + # Extract codec codes + codec_codes = ( + output.multimodal_output["code_predictor_codes"] + .to(torch.long) + .transpose(0, 1) + .cpu() + .to(torch.long) + .reshape(-1) + .tolist() + ) + + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=codec_codes, + multi_modal_data=None, + ) + ) + + return code2wav_inputs +``` + +## Testing + +For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md). + +## Adding a Model Recipe + +After implementing and testing your model, please add a model recipe to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository. This helps other users understand how to use your model with vLLM-Omni. + +### What to Include + +Your recipe should include: + +1. **Model Overview**: Brief description of the model and its capabilities +2. **Installation Instructions**: Step-by-step setup instructions including: + - Installing vllm-omni and dependencies + - Installing any additional required packages (e.g., xformers, diffusers) + - Any version requirements +3. **Usage Examples**: Command-line examples demonstrating how to run the model +4. **Configuration Details**: Important configuration parameters and their meanings + +### Example + +For reference, see the [LongCat recipe example](https://github.com/vllm-project/recipes/pull/179) which demonstrates the expected format and structure. 
+ +### Recipe Location + +Create your recipe file in the appropriate directory structure: +- For organization-specific models: `OrganizationName/ModelName.md` +- For general models: `ModelName.md` + +The recipe should be a Markdown file that provides clear, reproducible instructions for users to get started with your model. + +## Summary + +Adding a new model to vLLM-Omni involves: + +1. **Create model directory structure** with stage implementations +2. **Implement unified model class** that orchestrates stages +3. **Register model** in `registry.py` +4. **Create stage configuration** YAML file +5. **Implement stage input processors** for stage transitions +6. **Write tests** to verify functionality +7. **Add model recipe** to the [vllm-project/recipes](https://github.com/vllm-project/recipes) repository (see [Adding a Model Recipe](#adding-a-model-recipe) section) + +### Qwen3-Omni Reference Files + +For a complete reference implementation, see: + +- **Main model**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py` +- **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py` +- **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py` +- **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py` +- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml` +- **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py` +- **Registry**: `vllm_omni/model_executor/models/registry.py` +- **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py` + +For more information, see: +- [Architecture Overview](../../design/architecture_overview.md) +- [Supported Models](../../models/supported_models.md) +- [Stage Configuration Guide](../../configuration/stage_configs.md) diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md new file mode 100644 index 0000000000000000000000000000000000000000..a7df8c32297e5050320f61ad52cddf4b54583612 --- /dev/null +++ b/docs/contributing/profiling.md @@ -0,0 +1,149 @@ +# Profiling vLLM-Omni + +> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production. + +vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**. + +### 1. Set the Output Directory +Before running any script, set this environment variable. The system detects this and automatically saves traces here. + +```bash +export VLLM_TORCH_PROFILER_DIR=./profiles +``` + +### 2. Profiling Omni-Modality Models + +It is best to limit profiling to one iteration to keep trace files manageable. + +```bash +export VLLM_PROFILER_MAX_ITERS=1 +``` + +**Selective Stage Profiling** +The profiler is default to function across all stages. But It is highly recommended to profile specific stages by passing the stages list, preventing from producing too large trace files: +```python +# Profile all stages +omni_llm.start_profile() + +# Only profile Stage 1 +omni_llm.start_profile(stages=[1]) +``` + +```python +# Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni +omni_llm.start_profile(stages=[0, 2]) +``` + +**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`. + +```python +from vllm_omni import omni_llm + +profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + +# 1. 
Start profiling if enabled +if profiler_enabled: + omni_llm.start_profile(stages=[0]) + +# Initialize generator +omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) + +total_requests = len(prompts) +processed_count = 0 + +# Main Processing Loop +for stage_outputs in omni_generator: + + # ... [Output processing logic for text/audio would go here] ... + + # Update count to track when to stop profiling + processed_count += len(stage_outputs.request_output) + + # 2. Check if all requests are done to stop the profiler safely + if profiler_enabled and processed_count >= total_requests: + print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") + + # Stop the profiler while workers are still active + omni_llm.stop_profile() + + # Wait for traces to flush to disk + print("[Info] Waiting 30s for workers to write trace files to disk...") + time.sleep(30) + print("[Info] Trace export wait time finished.") + +omni_llm.close() +``` + + +**Examples**: + +1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py) + +2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py) + + +### 3. Profiling diffusion models + +Diffusion profiling is End-to-End, capturing encoding, denoising loops, and decoding. + +**CLI Usage:** +```python + +python image_to_video.py \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --image qwen-bear.png \ + --prompt "A cat playing with yarn, smooth motion" \ + \ + # Minimize Spatial Dimensions (Optional but helpful): + # Drastically reduces memory usage so the profiler doesn't + # crash due to overhead, though for accurate performance + # tuning you often want target resolutions. + --height 48 \ + --width 64 \ + \ + # Minimize Temporal Dimension (Frames): + # Video models process 3D tensors (Time, Height, Width). + # Reducing frames to the absolute minimum (2) keeps the + # tensor size small, ensuring the trace file doesn't become + # multi-gigabytes in size. + --num_frames 2 \ + \ + # Minimize Iteration Loop (Steps): + # This is the most critical setting for profiling. + # Diffusion models run the same loop X times. + # Profiling 2 steps gives you the exact same performance + # data as 50 steps, but saves minutes of runtime and + # prevents the trace viewer from freezing. + --num_inference_steps 2 \ + \ + --guidance_scale 5.0 \ + --guidance_scale_high 6.0 \ + --boundary_ratio 0.875 \ + --flow_shift 12.0 \ + --fps 16 \ + --output i2v_output.mp4 + +``` + +**Examples**: + +1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) + +2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) + +> **Note:** +As of now, asynchronous (online) profiling is not fully supported in vLLM-Omni. 
While start_profile() and stop_profile() methods exist, they are only reliable in offline inference scripts (e.g., the provided end2end.py examples). Do not use them in server-mode or streaming scenarios—traces may be incomplete or fail to flush. + +### 4. Analyzing Omni Traces + +Output files are saved to your configured ```VLLM_TORCH_PROFILER_DIR```. + +**Output** +**Chrome Trace** (```.json.gz```): Visual timeline of kernels and stages. Open in Perfetto UI. + +**Viewing Tools:** + +- [Perfetto](https://ui.perfetto.dev/)(recommended) +- ```chrome://tracing```(Chrome only) + +**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/) diff --git a/docs/design/architecture_overview.md b/docs/design/architecture_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..6793895cd463305dc70823c04d79b159558513c2 --- /dev/null +++ b/docs/design/architecture_overview.md @@ -0,0 +1,195 @@ +# Architecture Overview + +This document outlines the architectural design for vLLM-Omni. + +

+*Figure: Omni-Modality Model Architecture*

+ +# Goals + +The primary goal of the vLLM-Omni project is to build the fastest and easiest-to-use open-source Omni-Modality model inference & serving engine. vLLM-Omni extends the original vLLM, which was created to support large language models for text-based autoregressive (AR) generation tasks. vLLM-Omni is designed to support: + +* **Non-textual Output:** Enables the integration, efficient processing and output of various data types, including but not limited to, images, audio, and video, alongside text. +* **Non-Autoregressive Structure:** Support model structure beyond autoregressive, especially Diffusion Transformer (DiT), which is widely used in visual and audio generation. +* **Integration with vLLM Core:** Maintain compatibility and leverage existing vLLM key modules and optimizations where applicable. +* **Extensibility:** Design a modular and flexible architecture that can easily accommodate new modalities, model architectures, and output formats. + + +# Representative omni-modality models + +According to analysis for current popular open-source models, most of them have the combination of AR+DiT. Specifically, they can be further categorized into 3 types below: + +**DiT as a main structure, with AR as text encoder (e.g.: Qwen-Image)** + A powerful image generation foundation model capable of complex text rendering and precise image editing. + +

+*Figure: Qwen-Image*

+ +**AR as a main structure, with DiT as multi-modal generator (e.g. BAGEL)** + A unified multimodal comprehension and generation model, with cot text output and visual generation. + +

+*Figure: Bagel*

+ +**AR+DiT (e.g. Qwen-Omni)** + A natively end-to-end omni-modal LLM for multimodal inputs (text/image/audio/video...) and outputs (text/audio...). + +

+*Figure: Qwen-Omni*

+ +# vLLM-Omni main architecture + +

+*Figure: vLLM-Omni Main Architecture*

+ +## Key Components + +| Component | Description | +| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | +| **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch | +| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages | +| **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management | +| **Diffusion** | natively implemented and optimized using acceleration components | +| **OmniConnector** | supports fully disaggregation based on E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages | + +Disaggregated stages are managed through configuration, such as in the Qwen3-Omni example, where stages like Thinker, Talker, and Code2wav are defined as separate OmniStage instances with specific resources and input/output type. + +## Main features + +vLLM-Omni aims to be fast, flexible, and easy to use with the following features: + +### Performance and Acceleration + +The framework achieves high performance through several optimization techniques: + +* **Efficient AR Support:** Leverages efficient KV cache management inherited from vLLM. +* **Pipelined Execution:** Uses pipelined stage execution overlapping to ensure high throughput. +* **Full Disaggregation:** Relies on the OmniConnector and dynamic resource allocation across stages. +* **Diffusion Acceleration:** Includes integrated support for diffusion acceleration. This is managed by the acceleration layer, which handles: + * **Cache:** Includes DBCache, TeaCache and third-party integration(e.g., [cache-dit](https://github.com/vipshop/cache-dit)). + * **Parallelism:** Supports TP, CP, USP, and CFG. + * **Attention:** Provides an interface for third-party integration (e.g., FA3, SAGE, MindIE-SD). + * **Quantization:** Supports various quantization implementations including FP8 and AWQ. + * **FusedOps:** Allows for custom and third-party integration. + +### Flexibility and Usability + +vLLM-Omni is designed to be flexible and straightforward for users: + +* **Heterogeneous Pipeline Abstraction:** Manages complex model workflows effectively. +* **Hugging Face Integration:** Offers seamless integration with popular Hugging Face models. +* **Distributed Inference:** Supports tensor, pipeline, data, and expert parallelism. +* **Streaming Outputs:** Supports streaming outputs. +* **Unified API:** Provides a consistent and unified API interface compatible with vLLM. +* **OpenAI-compatible API Server:** Includes a FastAPI-based server for online serving that is compatible with the OpenAI API. + +# Interface design + +If you use vLLM, then you know how to use vLLM-Omni from Day 0: + +

+*Figure: vLLM-Omni interface design*

+ +Taking **Qwen3-Omni** as an example: + +## Offline Inference +The **Omni** class provides a Python interface for offline batched inference. Users initialize the Omni class with a Hugging Face model name and use the generate method, passing inputs that include both text prompts and multi-modal data: + +``` +# Create an omni_lm with HF model name. +from vllm_omni.entrypoints.omni import Omni + +omni_lm = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct") + +# Example prompts. +om_inputs = {"prompt": prompt, + "multi_modal_data": { + "video": video_frames, + "audio": audio_signal, + }} + +# Generate texts and audio from the multi-modality inputs. +outputs = omni_lm.generate(om_inputs, sampling_params_list) +``` + +## Online Serving +Similar to vLLM, vLLM-Omni also provides a FastAPI-based server for online serving. Users can launch the server using the vllm serve command with the `--omni` flag: + +``` +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + +Users can send requests to the server using curl: + +``` +# prepare user content +user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Why is this video funny?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + +# send the request +curl -sS -X POST http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @- < Talker). +Current connectors operate in D2H2D (device to host to device) mode. + +## Connector Choices + +| Use Case | Recommended Connector | Notes | +| :--- | :--- | :--- | +| Single node | SharedMemoryConnector | Auto-configured if no connector is specified. | +| Multi node (Mooncake) | MooncakeConnector | Requires Mooncake Master + metadata server. | +| Multi node (Yuanrong) | YuanrongConnector | Requires Yuanrong Datasystem + etcd. | + +## Core API + +The connector system is built around `OmniConnectorBase`. + +```python +class OmniConnectorBase(ABC): + @abstractmethod + def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, Optional[dict]]: + """ + Store data. + Returns: (success, serialized_size, metadata) + """ + pass + + @abstractmethod + def get(self, from_stage: str, to_stage: str, get_key: str, metadata: Optional[dict] = None) -> Optional[tuple[Any, int]]: + """ + Retrieve data. + Args: metadata - transport-specific handles returned by put() (e.g., SHM name). + Returns: (object, serialized_size) + """ + pass +``` + +### Metadata Passing + +Some connectors (e.g., SharedMemoryConnector) generate transient resources during `put()`. +This `metadata` must be passed through the control plane so `get()` can locate the data. + +## Configuration Model + +Define connectors in runtime: + +```yaml +runtime: + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 +``` + +Wire stages to connectors: + +```yaml +stage_args: + - stage_id: 0 + output_connectors: + to_stage_1: connector_of_shared_memory + + - stage_id: 1 + input_connectors: + from_stage_0: connector_of_shared_memory +``` + +If a pipeline edge has no explicit connector, the system auto-creates a +SharedMemoryConnector for that edge. + +## Relationship with vLLM + +vLLM provides specialized distributed mechanisms for specific artifacts: + +- KV Transfer (`vllm.distributed.kv_transfer`): optimized for KV caches. 
+- EC Transfer (`vllm.distributed.ec_transfer`): optimized for encoder embeddings. +- Device Communicators (`vllm.distributed.device_communicators`): low-level primitives (NCCL, SHM). + +vllm-omni complements this with a generalized connector abstraction: + +1. Unifies transport via a single `put`/`get` API for any stage artifact. +2. Enables DAG-style pipelines across processes or nodes with per-edge transports. +3. Can wrap vLLM-specific transfers for KV paths while keeping a consistent interface. + +## Operational Notes + +- Fail-fast config validation: missing expected edges cause startup failures. +- Missing payloads halt stages: verify connector wiring and metadata propagation. + +## Future Roadmap: D2D Transport + +Current connectors use D2H2D paths. Future versions will introduce direct +device-to-device connectors (NCCL, UCX, IPC) to reduce latency for large +tensor payloads. diff --git a/docs/design/feature/omni_connectors/mooncake_connector.md b/docs/design/feature/omni_connectors/mooncake_connector.md new file mode 100644 index 0000000000000000000000000000000000000000..fee409ecf09cfe53a53220d730731d897493f1cd --- /dev/null +++ b/docs/design/feature/omni_connectors/mooncake_connector.md @@ -0,0 +1,74 @@ +# MooncakeConnector + +## When to Use + +Best for multi-node distributed inference using Mooncake. + +## Installation + +```bash +# For CUDA-enabled systems (recommended) +pip install mooncake-transfer-engine + +# For non-CUDA systems +pip install mooncake-transfer-engine-non-cuda +``` + +## Start Mooncake Master + +```bash +# If you use Mooncake SSD storage +mkdir -p ./mc_storage + +mooncake_master \ + --rpc_port=50051 \ + --enable_http_metadata_server=true \ + --http_metadata_server_host=0.0.0.0 \ + --http_metadata_server_port=8080 \ + --metrics_port=9003 \ + --root_fs_dir=./mc_storage/ \ + --cluster_id=mc-local-1 & +``` + +## Configuration + +Define the connector in runtime: + +```yaml +runtime: + connectors: + connector_of_mooncake: + name: MooncakeConnector + extra: + host: "127.0.0.1" + metadata_server: "http://:8080/metadata" + master: ":50051" + segment: 512000000 + localbuf: 64000000 + proto: "tcp" +``` + +Wire stages to the connector: + +```yaml +stage_args: + - stage_id: 0 + output_connectors: + to_stage_1: connector_of_mooncake + + - stage_id: 1 + input_connectors: + from_stage_0: connector_of_mooncake +``` + +Parameters: + +- host: local worker IP registered in the metadata server. +- metadata_server: metadata server URL for discovery and setup. +- master: Mooncake Master address. +- segment: global memory segment size in bytes. +- localbuf: local buffer size in bytes. +- proto: transport protocol ("tcp" or "rdma"). + +For more details, refer to the +[Mooncake repository](https://github.com/kvcache-ai/Mooncake). diff --git a/docs/design/feature/omni_connectors/shared_memory_connector.md b/docs/design/feature/omni_connectors/shared_memory_connector.md new file mode 100644 index 0000000000000000000000000000000000000000..eb91a889a4daa5af4bb59bddd1c97f1526502098 --- /dev/null +++ b/docs/design/feature/omni_connectors/shared_memory_connector.md @@ -0,0 +1,27 @@ +# SharedMemoryConnector + +## When to Use + +Best for single-node deployments where stages run on the same host. It is +auto-configured when no explicit connector is specified for an edge. + +## How It Works + +- Small payloads (< threshold): serialized and passed inline in metadata. +- Large payloads (>= threshold): stored in shared memory; the SHM name is + returned in metadata. 
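+The inline-versus-SHM behaviour described above is easiest to see as a put/get round trip. The sketch below is a hypothetical usage example, not code from vLLM-Omni: it relies only on the `put`/`get` signatures defined on `OmniConnectorBase`, and the helper names, stage identifiers, and key scheme are illustrative assumptions.
+
+```python
+from typing import Any, Optional
+
+
+def send_stage_output(connector, request_id: str, payload: Any) -> Optional[dict]:
+    """Producer side (stage 0): store a payload destined for stage 1."""
+    # Below shm_threshold_bytes the connector serializes the payload inline
+    # into the returned metadata; above it, the payload is written to shared
+    # memory and the metadata carries the SHM name instead.
+    ok, _size, metadata = connector.put(
+        from_stage="stage_0",      # stage identifiers are illustrative
+        to_stage="stage_1",
+        put_key=f"{request_id}:stage_0_stage_1",  # illustrative key scheme
+        data=payload,
+    )
+    if not ok:
+        raise RuntimeError(f"put failed for request {request_id}")
+    # The metadata must reach the consumer through the control plane.
+    return metadata
+
+
+def recv_stage_input(connector, request_id: str, metadata: Optional[dict]) -> Any:
+    """Consumer side (stage 1): retrieve the payload using the metadata."""
+    result = connector.get(
+        from_stage="stage_0",
+        to_stage="stage_1",
+        get_key=f"{request_id}:stage_0_stage_1",
+        metadata=metadata,
+    )
+    if result is None:
+        raise RuntimeError(f"no payload found for request {request_id}")
+    obj, _size = result
+    return obj
+```
+
+Whether a given edge stays inline or spills to shared memory is purely a configuration concern (`shm_threshold_bytes`); the calling code is identical in both cases.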
+ +## Configuration + +```yaml +runtime: + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 +``` + +## Notes + +- Auto-mode uses SharedMemoryConnector if no connector is declared for an edge. diff --git a/docs/design/feature/omni_connectors/yuanrong_connector.md b/docs/design/feature/omni_connectors/yuanrong_connector.md new file mode 100644 index 0000000000000000000000000000000000000000..88989d31aa4d8bd75f27b089e4ba728d6ecd9d5a --- /dev/null +++ b/docs/design/feature/omni_connectors/yuanrong_connector.md @@ -0,0 +1,101 @@ +# YuanrongConnector + +## When to Use + +Best for multi-node distributed inference using Yuanrong Datasystem. + +## Mechanism + +Uses Yuanrong Datasystem's distributed KV store (`datasystem.kv_client`). + +- Data Plane: TCP or RDMA for high-bandwidth transfer. +- Control Plane: Yuanrong Datasystem workers and etcd. +- Keying: deterministic keys based on `put_key` (often composed as `request_id:fromStage_toStage`). + +## Installation + +```bash +pip install openyuanrong-datasystem +``` + +## Start etcd + +```bash +# Download and install etcd (v3.5.12 or higher) +ETCD_VERSION="v3.5.12" +ETCD_ARCH="linux-arm64" +wget https://github.com/etcd-io/etcd/releases/download/${ETCD_VERSION}/etcd-${ETCD_VERSION}-${ETCD_ARCH}.tar.gz +tar -xvf etcd-${ETCD_VERSION}-${ETCD_ARCH}.tar.gz +cd etcd-${ETCD_VERSION}-${ETCD_ARCH} +sudo cp etcd etcdctl /usr/local/bin/ + +# Start etcd +etcd \ + --name etcd-single \ + --data-dir /tmp/etcd-data \ + --listen-client-urls http://0.0.0.0:2379 \ + --advertise-client-urls http://0.0.0.0:2379 \ + --listen-peer-urls http://0.0.0.0:2380 \ + --initial-advertise-peer-urls http://0.0.0.0:2380 \ + --initial-cluster etcd-single=http://0.0.0.0:2380 & + +# Verify etcd is running +etcdctl --endpoints "127.0.0.1:2379" put key "value" +etcdctl --endpoints "127.0.0.1:2379" get key +``` + +For production environments, refer to the +[official etcd clustering documentation](https://etcd.io/docs/current/op-guide/clustering/). + +## Start Datasystem Worker + +```bash +# Replace ${ETCD_IP} with etcd node IP, ${WORKER_IP} with local node IP +dscli start -w \ + --worker_address "${WORKER_IP}:31501" \ + --etcd_address "${ETCD_IP}:2379" \ + --shared_memory_size_mb 20480 +``` + +To stop the worker: + +```bash +dscli stop --worker_address "${WORKER_IP}:31501" +``` + +## Configuration + +Define the connector in runtime: + +```yaml +runtime: + connectors: + connector_of_yuanrong: + name: YuanrongConnector + extra: + host: "127.0.0.1" + port: 31501 + get_sub_timeout_ms: 1000 +``` + +Wire stages to the connector: + +```yaml +stage_args: + - stage_id: 0 + output_connectors: + to_stage_1: connector_of_yuanrong + + - stage_id: 1 + input_connectors: + from_stage_0: connector_of_yuanrong +``` + +Parameters: + +- host: datasystem worker host. +- port: datasystem worker port. +- get_sub_timeout_ms: get timeout in milliseconds (0 for no timeout). + +For more details, refer to the +[Yuanrong Datasystem repository](https://atomgit.com/openeuler/yuanrong-datasystem). 
diff --git a/docs/design/feature/ray_based_execution.md b/docs/design/feature/ray_based_execution.md new file mode 100644 index 0000000000000000000000000000000000000000..f69649d227ab6f8469f1b70d44b865ba52d08b45 --- /dev/null +++ b/docs/design/feature/ray_based_execution.md @@ -0,0 +1,59 @@ +# Distributed utils + +This directory (vllm_omni/distributed/ray_utils) contains utilities for distributed execution in vllm-omni, supporting both **Ray** and **Multiprocessing** backends. +## 1. Installation +```bash +pip install "ray[default]" +``` +## 2. Ray Utils + +The `ray_utils` module provides helper functions for managing Ray clusters and actors, which is essential for: +* **Multi-node deployment**: Running pipeline stages across different physical machines. +* **Resource management**: Efficient GPU/CPU allocation. + +### 2.1 Basic Usage + +To use the Ray backend, specify `worker_backend="ray"` when initializing the engine. + +**Command Line Example:** +```bash +vllm serve Qwen/Qwen2.5-Omni-7B \ + --omni \ + --port 8091 \ + --worker-backend ray \ + --ray-address auto +``` + +### 2.2 Cluster Setup + +**Step 1: Start Head Node** +Run this on your primary machine: +```bash +ray start --head --port=6399 +``` + +**Step 2: Connect Worker Nodes** +Run this on each worker machine: +```bash +ray start --address=:6399 +``` + +> **Tip**: For a complete cluster setup script, refer to the vLLM example: +> [run_cluster.sh](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/run_cluster.sh) + +### 2.3 Distributed Connector Support + +When running on Ray, the system automatically adapts its communication strategy: + +* **Cross-Node**: Recommended to use `MooncakeConnector` (requires separate configuration). +* **Same-Node**: Can still use `SharedMemoryConnector` for efficiency, or Ray's native object store (plasma). +* **SHM threshold default differs**: when `worker_backend="ray"`, the SharedMemoryConnector default threshold is set to `sys.maxsize`, which forces payloads to go inline (no SHM). Override `shm_threshold_bytes` in the connector config if you want SHM for Ray runs. + +### 2.4 Internal Helpers + +* **`initialize_ray_cluster`**: Connects to an existing Ray cluster or starts a local one. + +## 3. Troubleshooting + +* **Connection Issues**: Ensure the Ray head node is accessible and ports (default 6399 in this example) are open. +* **Version Mismatch**: Ensure all nodes run the same version of Ray and Python. diff --git a/docs/design/index.md b/docs/design/index.md new file mode 100644 index 0000000000000000000000000000000000000000..31420550fbd1791a2c618ac47fcb1a574bd9a0fc --- /dev/null +++ b/docs/design/index.md @@ -0,0 +1,18 @@ +# Design Documents + +This section contains design documents and architecture specifications for vLLM-Omni. + +## Architecture Documents + +- [Architecture Overview](architecture_overview.md) + +## Feature Design Documents + +- [Disaggregated Inference](feature/disaggregated_inference.md) +- [Ray-based Execution](feature/ray_based_execution.md) + +## Module Design Documents + +- [AR Module](module/ar_module.md) +- [DIT Module](module/dit_module.md) +- [Entrypoint Module](module/entrypoint_module.md) diff --git a/docs/design/module/ar_module.md b/docs/design/module/ar_module.md new file mode 100644 index 0000000000000000000000000000000000000000..c0f7cddf046392b548fbea0c6ca7e99396f56f4e --- /dev/null +++ b/docs/design/module/ar_module.md @@ -0,0 +1,414 @@ +# AutoRegressive (AR) Module + +## 1. 
Overview + +The AutoRegressive (AR) module in vLLM-Omni handles autoregressive generation stages, primarily used for text, chain-of-thought(COT), and audio latent tokens generation stages in multi-stage models like Qwen2.5-Omni, Qwen3-Omni, BAGEL, .etc. Unlike some representative non-autoregressive generation stages (e.g., Diffusion), AR stages generate tokens sequentially, one at a time, following the standard transformer decoder pattern. + +The AR module of vLLM-Omni extends vLLM's core components to support: + +- **Multimodal inputs/outputs**: Processing images, videos, and audio alongside text +- **Direct embedding transfer**: Passing pre-computed prompt embeddings between pipeline stages via serialized payloads +- **Additional information flow**: Carrying per-request metadata (tensors, lists) through the pipeline +- **Hidden state exposure**: Exposing per-request hidden representations for downstream stages +- **Basic generator support**: Support some basic heterogeneous architecture such as Convolution, LSTM, etc. + +As shown in the [end2end example](../../user_guide/examples/offline_inference/qwen3_omni.md), AR module can be widely applied across multiple stages, generating text tokens in thinker(AR), audio latent tokens in talker(AR) and audio wave in code2wav(Convolution). + +## 2. Relationship with vLLM + +The AR module builds upon vLLM main framework through inheritance, extending core classes while preserving compatibility with vLLM's scheduling, batching, KV cache management, and execution mechanisms. + +### Inheritance Hierarchy +- Scheduler + +```mermaid +classDiagram + class VLLMScheduler { + +schedule() SchedulerOutput + +update_from_output() EngineCoreOutputs + } + class OmniARScheduler { + +schedule() SchedulerOutput + } + class OmniGenerationScheduler { + +schedule() SchedulerOutput + +update_from_output() EngineCoreOutputs + } + VLLMScheduler <|-- OmniARScheduler + VLLMScheduler <|-- OmniGenerationScheduler +``` +- Worker + +```mermaid +classDiagram + class GPUWorker { + +init_device() + +model_runner + } + class GPUARWorker { + +init_device() + } + class GPUGenerationWorker { + +init_device() + } + GPUWorker <|-- GPUARWorker + GPUWorker <|-- GPUGenerationWorker +``` +- ModelRunner + +```mermaid +classDiagram + class GPUModelRunner { + +execute_model() + +sample_tokens() + } + class OmniGPUModelRunner { + +_update_states() + +_preprocess() + +_model_forward() + } + class GPUARModelRunner { + +execute_model() + +sample_tokens() + } + class GPUGenerationModelRunner { + +execute_model() + } + GPUModelRunner <|-- OmniGPUModelRunner + OmniGPUModelRunner <|-- GPUARModelRunner + OmniGPUModelRunner <|-- GPUGenerationModelRunner +``` +- InputProcessor/OutputProcessor + +```mermaid +classDiagram + class InputProcessor { + +process_inputs() EngineCoreRequest + } + class OmniInputProcessor { + +process_inputs() OmniEngineCoreRequest + } + + class VLLMOutputProcessor { + +process_outputs() OutputProcessorOutput + } + class MultimodalOutputProcessor { + +process_outputs() OutputProcessorOutput + +_route_and_normalize() + } + InputProcessor <|-- OmniInputProcessor + VLLMOutputProcessor <|-- MultimodalOutputProcessor +``` + +### Key Extensions + +- **Scheduler**: `OmniARScheduler` extends `vllm.v1.core.sched.scheduler.Scheduler` to enrich scheduled requests with omni-specific payloads +- **Worker**: `GPUARWorker` extends `vllm.v1.worker.gpu_worker.Worker` to initialize AR-specific model runners +- **ModelRunner**: `GPUARModelRunner` extends `OmniGPUModelRunner` → 
`vllm.v1.worker.gpu_model_runner.GPUModelRunner` to expose hidden states and handle multimodal outputs +- **InputProcessor**: `OmniInputProcessor` extends `vllm.v1.engine.input_processor.InputProcessor` to serialize prompt embeddings and additional information +- **OutputProcessor**: `MultimodalOutputProcessor` extends `vllm.v1.engine.output_processor.OutputProcessor` to route and accumulate multimodal outputs + +## 3. Scheduler Design + +The AR module provides two scheduler implementations: one for standard autoregressive generation and one for basic heterogeneous architectures. + +### Request Flow + +The following diagram illustrates the request flow through the AR module components: + +```mermaid +flowchart TD + A[OmniInputProcessor] -->|OmniEngineCoreRequest| B[OmniARScheduler] + B -->|schedule: OmniNewRequestData| C[GPUARWorker] + C -->|SchedulerOutput| D[GPUARModelRunner] + D -->|execute_model: None| E[Model Forward Pass] + E -->|hidden_states, logits| D + D -->|sample_tokens: OmniModelRunnerOutput| F[OmniARScheduler] + F -->|update_from_output| G[MultimodalOutputProcessor] + G -->|RequestOutput| H[Client/Downstream Stage] + + style A fill:#e1f5ff + style B fill:#fff4e1 + style C fill:#e8f5e9 + style D fill:#f3e5f5 + style G fill:#fce4ec +``` + +The flow follows vLLM's standard pattern: input processing → scheduling → worker execution → output processing, with omni-specific enrichments at each stage. + +### OmniARScheduler + +`OmniARScheduler` extends the base vLLM scheduler with minimal modifications, focusing on enriching scheduled requests with omni-specific payloads. + +#### Modified API: `schedule()` + +The scheduler wraps base `NewRequestData` entries with `OmniNewRequestData` to include prompt embeddings and additional information: + +```python +def schedule(self) -> SchedulerOutput: + scheduler_output = super().schedule() + # Rewrap base NewRequestData entries with OmniNewRequestData + new_list = [] + for nr in scheduler_output.scheduled_new_reqs: + request = self.requests.get(nr.req_id) + omni_nr = OmniNewRequestData( + req_id=nr.req_id, + prompt_token_ids=nr.prompt_token_ids, + # ... other base fields ... + prompt_embeds=getattr(request, "prompt_embeds", None), + additional_information=getattr(request, "additional_information", None), + ) + new_list.append(omni_nr) + scheduler_output.scheduled_new_reqs = new_list + return scheduler_output +``` + +The `update_from_output()` method remains unchanged, inheriting standard request lifecycle management from the base scheduler. + +### OmniGenerationScheduler + +`OmniGenerationScheduler` implements a fast-path scheduling strategy for basic heterogeneous architectures that process all input tokens in a single step. + +#### Modified API: `schedule()` + +Allocates all input tokens for a request at once (or 1 placeholder if zero), falling back to default scheduling if budget is insufficient: + +```python +def schedule(self) -> SchedulerOutput: + # Fast path: allocate all input tokens at once + while self.waiting and token_budget > 0: + request = self.waiting.peek_request() + required_tokens = max(getattr(request, "num_prompt_tokens", 0), 1) + if required_tokens > token_budget: + break # Fall back to default scheduling + # Allocate and schedule... +``` + +#### Modified API: `update_from_output()` + +Marks requests as finished immediately after one step, since generation models complete in a single forward pass: + +```python +def update_from_output(self, ...) -> dict[int, EngineCoreOutputs]: + # ... 
+ # Diffusion request: completes in one step + request.status = RequestStatus.FINISHED_STOPPED + kv_transfer_params = self._free_request(request) + # ... +``` + +## 4. Worker and ModelRunner Design + +### GPUARWorker + +`GPUARWorker` initializes the AR-specific model runner while maintaining standard device initialization: + +```python +class GPUARWorker(GPUWorker): + def init_device(self): + # ... standard device initialization ... + self.model_runner = GPUARModelRunner(self.vllm_config, self.device) +``` + +### GPUARModelRunner + +`GPUARModelRunner` follows vLLM's two-phase execute/sample flow while exposing hidden states and multimodal outputs. + +#### Two-Phase Execution + +**Phase 1: `execute_model()`** - Runs forward pass and stores state: +- Computes logits from hidden states +- Stores `ExecuteModelState` with hidden states, logits, and multimodal outputs +- Returns `None` to defer sampling + +**Phase 2: `sample_tokens()`** - Samples tokens and builds output: +- Retrieves stored state from `execute_model()` +- Samples tokens using logits +- Extracts per-request hidden states and multimodal outputs +- Builds `OmniModelRunnerOutput` with `pooler_output` containing hidden states + +```python +def sample_tokens(self, grammar_output) -> OmniModelRunnerOutput: + # Retrieve stored state + hidden_states, multimodal_outputs = self.execute_model_state + + # Sample tokens + sampler_output = self._sample(logits, spec_decode_metadata) + + # Extract per-request hidden states + pooler_output = [] + for rid in req_ids: + hidden_slice = hidden_states_cpu[start:end] + payload = {"hidden": hidden_slice} + # Add multimodal outputs if present + pooler_output.append(payload) + + return OmniModelRunnerOutput( + pooler_output=pooler_output, + # ... other fields ... + ) +``` + +### GPUGenerationModelRunner + +`GPUGenerationModelRunner` implements a simplified single-phase execution for basic heterogeneous architectures: + +- No logits computation or token sampling +- Direct generation from forward pass in model implementation +- Returns outputs via `pooler_output` immediately after forward pass + +### OmniGPUModelRunner + +`OmniGPUModelRunner` provides shared functionality for both AR and Generation runners: + +#### Prompt Embeddings Overlay + +During prefill, overlays custom `prompt_embeds` from request state onto `inputs_embeds`: + +```python +def _collect_additional_information_for_prefill(self, num_scheduled_tokens_np): + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + # Overlay prompt_embeds for prefill portion + if pe_cpu is not None: + src = pe_cpu[num_computed_tokens:num_computed_tokens + overlay_len] + self.inputs_embeds[start_offset:start_offset + overlay_len].copy_(src) +``` + +#### Additional Information Processing + +Decodes and manages `additional_information` payloads: +- Decodes serialized payloads → CPU tensors in request state +- Passes runtime information to model via `runtime_additional_information` kwarg +- Processes model-provided updates via `postprocess()` hook +- Merges updates back into request state + +#### M-RoPE Position Initialization + +For multimodal models using M-RoPE (e.g., Qwen2-VL), computes position encodings from multimodal feature metadata (image grids, video grids, audio features). + +## 5. 
Input/Output Processing + +### Processing Pipeline + +The input/output processing pipeline handles serialization, routing, and accumulation of multimodal data: + +```mermaid +sequenceDiagram + participant Client + participant OmniInputProcessor + participant Scheduler + participant ModelRunner + participant MultimodalOutputProcessor + participant Client + + Client->>OmniInputProcessor: prompt + prompt_embeds + additional_info + OmniInputProcessor->>OmniInputProcessor: Serialize tensors to payloads + OmniInputProcessor->>Scheduler: OmniEngineCoreRequest (with payloads) + Scheduler->>ModelRunner: OmniNewRequestData (with payloads) + ModelRunner->>ModelRunner: Decode payloads → CPU tensors + ModelRunner->>ModelRunner: Overlay prompt_embeds on inputs_embeds + ModelRunner->>ModelRunner: Forward pass with runtime_additional_information + ModelRunner->>ModelRunner: Extract hidden states + multimodal outputs + ModelRunner->>MultimodalOutputProcessor: OmniModelRunnerOutput (pooler_output) + MultimodalOutputProcessor->>MultimodalOutputProcessor: Route by output_type + MultimodalOutputProcessor->>MultimodalOutputProcessor: Accumulate tensors in OmniRequestState + MultimodalOutputProcessor->>MultimodalOutputProcessor: Consolidate tensor lists + MultimodalOutputProcessor->>Client: RequestOutput (with multimodal_output) +``` + +### OmniInputProcessor + +`OmniInputProcessor` extends the base input processor to serialize prompt embeddings and additional information for inter-stage transfer. + +#### Payload Serialization + +Converts PyTorch tensors to serialized payloads: + +```python +def process_inputs(self, ...) -> OmniEngineCoreRequest: + # Serialize prompt_embeds + if "prompt_embeds" in decoder_inputs: + pe_cpu = decoder_inputs["prompt_embeds"].detach().to("cpu").contiguous() + prompt_embeds_payload = PromptEmbedsPayload( + data=pe_cpu.numpy().tobytes(), + shape=[seq_len, hidden_size], + dtype=dtype_str, + ) + + # Serialize additional_information + if "additional_information" in decoder_inputs: + entries = {} + for key, value in raw_info.items(): + if isinstance(value, torch.Tensor): + entry = AdditionalInformationEntry( + tensor_data=value.numpy().tobytes(), + tensor_shape=list(value.shape), + tensor_dtype=dtype_str, + ) + entries[key] = entry + additional_information_payload = AdditionalInformationPayload(entries=entries) + + return OmniEngineCoreRequest( + # ... standard fields ... + prompt_embeds=prompt_embeds_payload, + additional_information=additional_information_payload, + ) +``` + +### MultimodalOutputProcessor + +`MultimodalOutputProcessor` routes outputs by modality type and accumulates multimodal tensors. 
+ +#### Output Routing + +Routes `EngineCoreOutput` by `output_type` attribute: +- `"text"`: Standard text generation path +- `"image"`, `"audio"`, `"latents"`: Extract from `pooling_output` or `multimodal_outputs` +- Fallback: Heuristic based on presence of `pooling_output` + +#### Tensor Accumulation + +`OmniRequestState` accumulates multimodal tensors across multiple steps: + +```python +def add_multimodal_tensor(self, payload, mm_type): + # Normalize payload to dict + incoming = {mm_type or "hidden": payload} + + # Accumulate: convert tensors to lists for deferred concatenation + if isinstance(v, torch.Tensor) and isinstance(existing, torch.Tensor): + self.mm_accumulated[k] = [existing, v] # List accumulation +``` + +Before final output, consolidates tensor lists via concatenation: + +```python +def _consolidate_multimodal_tensors(self): + for k, v in self.mm_accumulated.items(): + if isinstance(v, list) and isinstance(v[0], torch.Tensor): + self.mm_accumulated[k] = torch.cat(v, dim=0) # Concatenate +``` + +The consolidated tensors are attached to `RequestOutput.multimodal_output` for consumption by downstream stages or clients. + +## 6. Summary + +The AR module of vLLM-Omni extends vLLM through strategic inheritance and minimal API modifications: + +### Key Design Patterns + +1. **Inheritance over composition**: Extends vLLM classes to preserve compatibility with existing scheduling, batching, and execution mechanisms +2. **Payload serialization**: Uses `PromptEmbedsPayload` and `AdditionalInformationPayload` for efficient inter-stage data transfer +3. **Two-phase execution**: Maintains vLLM's execute/sample separation for AR models while supporting single-phase execution for generation models +4. **Multimodal routing**: Routes outputs by `output_type` and accumulates tensors incrementally to support streaming + +### Differences from vLLM + +- **Payload support**: Serialized prompt embeddings and additional information enable direct transfer between pipeline stages +- **Multimodal handling**: Extended input/output processors support images, audio, and other modalities alongside text +- **Hidden state exposure**: AR model runners expose per-request hidden states via `pooler_output` for downstream consumption +- **Generation scheduler**: Fast-path scheduling for basic heterogeneous architectures that complete in one step + +The AR module seamlessly integrates with vLLM's existing infrastructure while adding the necessary extensions for multi-stage, multimodal generation pipelines. diff --git a/docs/design/module/dit_module.md b/docs/design/module/dit_module.md new file mode 100644 index 0000000000000000000000000000000000000000..4bc3f2b6ceeb61595e8d37479e4b4033cdb108ba --- /dev/null +++ b/docs/design/module/dit_module.md @@ -0,0 +1,906 @@ +--- +toc_depth: 4 +--- + +# Diffusion Module Architecture Design + +The vLLM-Omni diffusion module (`vllm_omni/diffusion`) is a high-performance inference engine for diffusion models, designed with a modular architecture that separates concerns across multiple components. It provides efficient execution for non-autoregressive generation tasks such as image and video generation. + +This document describes the architecture design of the diffusion module, including the diffusion engine, scheduler, worker, diffusion pipeline, and acceleration components. + +

+*Figure: vLLM-Omni Diffusion Module Components (Main Components of the Diffusion Module)*

+
+
+**Table of Contents:**
+
+- [Architecture Overview](#architecture-overview)
+- [Diffusion Engine](#1-diffusion-engine)
+- [Scheduler](#2-scheduler)
+- [Worker](#3-worker)
+- [Diffusion Pipeline](#4-diffusion-pipeline)
+- [Acceleration Components](#5-acceleration-components)
+    - [Attention Backends](#51-attention-backends)
+    - [Parallel Attention](#52-parallel-attention)
+    - [Cache Backends](#53-cache-backends)
+    - [Parallel Strategies](#54-parallel-strategies)
+- [Data Flow](#6-data-flow)
+
+---
+
+## Architecture Overview
+
+The diffusion module follows a **multi-process, distributed architecture** with clear separation of concerns:
+

+*Figure: vLLM-Omni Diffusion Module Architecture (Diffusion Architecture Overview)*

+ + +--- + +## 1. Diffusion Engine + +**Location**: `vllm_omni/diffusion/diffusion_engine.py` + +### Responsibilities + +The `DiffusionEngine` is the **orchestrator** of the diffusion inference system. It manages the lifecycle of worker processes and coordinates the execution flow. + +### Key Components + +#### 1.1 Initialization + +```python +class DiffusionEngine: + def __init__(self, od_config: OmniDiffusionConfig): + self.od_config = od_config + self.post_process_func = get_diffusion_post_process_func(od_config) + self.pre_process_func = get_diffusion_pre_process_func(od_config) + self._processes: list[mp.Process] = [] + self._make_client() +``` + +**Key Features**: + +- **Pre/Post Processing**: Registers model-specific pre-processing and post-processing functions via registry pattern + +- **Worker Management**: Launches and manages multiple worker processes (one per GPU) + +- **Process Isolation**: Uses multiprocessing for true parallelism + +#### 1.2 Worker Launch Process + +The engine launches workers using a **spawn** method: + +```python +def _launch_workers(self, broadcast_handle): + # Creates one process per GPU + for i in range(num_gpus): + process = mp.Process( + target=worker_proc.worker_main, + args=(i, od_config, writer, broadcast_handle), + name=f"DiffusionWorker-{i}", + ) + process.start() +``` + +**Design Decisions**: + +- **Spawn Method**: Ensures clean state for each worker (no shared memory issues) + +- **Pipe Communication**: Uses `mp.Pipe` for initialization handshake + +- **Device Selection**: Each worker is assigned a specific GPU (`cuda:{rank}`) + +#### 1.3 Request Processing Flow + +```python +def step(self, requests: list[OmniDiffusionRequest]): + # 1. Pre-process requests + requests = self.pre_process_func(requests) + + # 2. Send to scheduler and wait for response + output = self.add_req_and_wait_for_response(requests) + + # 3. Post-process results + result = self.post_process_func(output.output) + return result +``` + +**Flow**: + +1. **Pre-processing**: Applies model-specific transformations + +2. **Scheduling**: Delegates to scheduler for distribution + +3. **Post-processing**: Converts raw outputs to final format (e.g., PIL images) + +--- + +## 2. Scheduler + +**Location**: `vllm_omni/diffusion/scheduler.py` + +### Architecture + +The `Scheduler` is implemented as a **Singleton** pattern to ensure a single coordination point across the system, i.e., only one scheduler instance exists for coordination. 
+ +### Key Components + +#### 2.1 Message Queue System + +```python +class Scheduler: + def initialize(self, od_config: OmniDiffusionConfig): + # Broadcast queue: scheduler -> all workers + self.mq = MessageQueue( + n_reader=od_config.num_gpus, + n_local_reader=od_config.num_gpus, + local_reader_ranks=list(range(od_config.num_gpus)), + ) + + # Result queue: rank 0 worker -> scheduler + self.result_mq = None # Initialized later +``` + +**Communication Pattern**: + +- **Broadcast Queue**: One-to-many communication (scheduler → all workers) + +- **Result Queue**: One-to-one communication (rank 0 → scheduler) + +- **Shared Memory**: Uses `MessageQueue` (ZMQ-based) for efficient IPC + +#### 2.2 Request Distribution + +```python +def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: + # Broadcast request to all workers + self.mq.enqueue(requests) + + # Wait for result from Rank 0 + output = self.result_mq.dequeue() + return output +``` + +**Design Features**: + +- **Broadcast Model**: All workers receive the same request (for tensor parallelism) + +- **Single Response**: Only rank 0 sends results back (avoids duplicate outputs) + +- **Synchronous**: Blocks until result is received (can be made async) + +#### 2.3 Singleton Pattern + +```python +class Scheduler: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + +# Global singleton instance +scheduler = Scheduler() +``` + +**Benefits**: + +- **Single Point of Control**: Ensures consistent state + +- **Easy Access**: Global `scheduler` instance accessible everywhere + +- **Resource Management**: Centralized queue management + +--- + +## 3. Worker + +**Location**: `vllm_omni/diffusion/worker/gpu_worker.py` + +### Architecture + +Workers are **independent processes** that execute the actual model inference. Each worker runs on a dedicated GPU and participates in distributed inference. + +### Key Components + +#### 3.1 Worker Process Structure + +```python +class WorkerProc: + def __init__(self, od_config, gpu_id, broadcast_handle): + # Initialize ZMQ context for IPC + self.context = zmq.Context(io_threads=2) + + # Connect to broadcast queue (receive requests) + self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) + + # Create result queue (only rank 0) + if gpu_id == 0: + self.result_mq = MessageQueue(n_reader=1, ...) + + # Initialize GPU worker + self.worker = GPUWorker(local_rank=gpu_id, rank=gpu_id, od_config=od_config) +``` + +**Initialization Steps**: + +1. **IPC Setup**: Creates ZMQ context and message queues + +2. **Distributed Environment Setup**: Initializes PyTorch distributed communication + + - For CUDA GPUs: Uses NCCL (fast GPU communication) + + - For NPU: Uses HCCL (Huawei Collective Communications Library) + + - For other devices: Uses appropriate backend (GLOO, MCCL, etc.) + +3. **Model Loading**: Loads diffusion pipeline on assigned GPU + +4. **Cache Setup**: Enables cache backend if configured. 
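+
+For the distributed-environment setup in step 2, the backend choice roughly
+follows the device type. The sketch below is illustrative only and assumes
+plain `torch.distributed`; the actual helper used inside vllm_omni may differ,
+and `MASTER_ADDR`/`MASTER_PORT` must already be set in the environment.
+
+```python
+import torch
+import torch.distributed as dist
+
+
+def pick_backend() -> str:
+    # CUDA GPUs -> NCCL, Ascend NPUs -> HCCL, everything else -> GLOO.
+    if torch.cuda.is_available():
+        return "nccl"
+    if hasattr(torch, "npu") and torch.npu.is_available():  # requires torch_npu
+        return "hccl"
+    return "gloo"
+
+
+def init_distributed(rank: int, world_size: int) -> None:
+    dist.init_process_group(backend=pick_backend(), rank=rank, world_size=world_size)
+```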
+ +#### 3.2 GPU Worker + +```python +class GPUWorker: + def init_device_and_model(self): + # Set distributed environment variables + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Initialize PyTorch distributed + init_distributed_environment(world_size, rank) + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_size=parallel_config.data_parallel_size, + cfg_parallel_size=parallel_config.cfg_parallel_size, + sequence_parallel_size=parallel_config.sequence_parallel_size, + tensor_parallel_size=parallel_config.tensor_parallel_size, + pipeline_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # Load model + model_loader = DiffusersPipelineLoader(load_config) + self.pipeline = model_loader.load_model(od_config, load_device=f"cuda:{rank}") + + # Setup cache backend + from vllm_omni.diffusion.cache.selector import get_cache_backend + self.cache_backend = get_cache_backend(od_config.cache_backend, od_config.cache_config) + + if self.cache_backend is not None: + self.cache_backend.enable(self.pipeline) +``` + +**Key Features**: + +- **Tensor Parallelism**: Supports multi-GPU tensor parallelism via PyTorch distributed + +- **Model Loading**: Uses `DiffusersPipelineLoader` for efficient weight loading + +- **Cache Integration**: Enables cache backends (TeaCache, cache-dit, etc.) transparently + +#### 3.3 Worker Busy Loop + +```python +def worker_busy_loop(self): + while self._running: + # 1. Receive unified message (generation request, RPC request, or shutdown) + msg = self.recv_message() + + # 2. Route message based on type + if isinstance(msg, dict) and msg.get("type") == "rpc": + # Handle RPC request + result, should_reply = self.execute_rpc(msg) + if should_reply: + self.return_result(result) + + elif isinstance(msg, dict) and msg.get("type") == "shutdown": + # Handle shutdown message + self._running = False + + else: + # Handle generation request (OmniDiffusionRequest list) + output = self.worker.execute_model(msg, self.od_config) + self.return_result(output) +``` + +**Execution Flow**: + +1. **Receive**: Dequeues unified messages from shared memory queue + +2. **Route**: Handles different message types (generation, RPC, shutdown) + +3. **Execute**: Runs forward pass through pipeline for generation requests + +4. **Respond**: Sends results back (rank 0 for generation, specified rank for RPC) + +#### 3.4 Model Execution + +```python +@torch.inference_mode() +def execute_model(self, reqs: list[OmniDiffusionRequest], od_config): + req = reqs[0] # TODO: support batching + + # Refresh cache backend if enabled + if self.cache_backend is not None and self.cache_backend.is_enabled(): + self.cache_backend.refresh(self.pipeline, req.num_inference_steps) + + # Set forward context for parallelism + with set_forward_context( + vllm_config=self.vllm_config, + omni_diffusion_config=self.od_config + ): + output = self.pipeline.forward(req) + return output +``` + +The model execution leverages multiple parallelism strategies that are transparently applied during the forward pass. 
The `set_forward_context()` context manager makes parallel group information available throughout the forward pass: + +```python +# Inside transformer layers, parallel groups are accessed via: +from vllm_omni.diffusion.distributed.parallel_state import ( + get_sp_group, get_dp_group, get_cfg_group, get_pp_group +) +``` + +**Optimizations**: + +- **Cache Refresh**: Clears cache state before each generation for clean state + +- **Context Management**: Forward context ensures parallel groups are available during execution + +- **Single Request**: Currently processes one request at a time (batching TODO) + +--- + +## 4. Diffusion Pipeline + +**Location**: `vllm_omni/diffusion/models/*/pipeline_*.py` + +The pipeline is the **model-specific implementation** that orchestrates the diffusion process. Different models (QwenImage, Wan2.2, Z-Image) have their own pipeline implementations. + +Most pipeline implementation are referred from `diffusers`. The multi-step diffusion loop is usually the most time-consuming part during the overall inference process, which is defined by the `diffuse` function in the pipeline class. An example is as follows: + +```python +def diffuse(self, ...): + for i, t in enumerate(timesteps): + # Forward pass for positive prompt + transformer_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": prompt_embeds, + } + noise_pred = self.transformer(**transformer_kwargs)[0] + + # Forward pass for negative prompt (CFG) + if do_true_cfg: + neg_transformer_kwargs = {...} + neg_transformer_kwargs["cache_branch"] = "negative" + neg_noise_pred = self.transformer(**neg_transformer_kwargs)[0] + + # Combine predictions + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents)[0] + + return latents +``` + +**Key Features**: + +- **CFG Support**: Handles classifier-free guidance with separate forward passes + +- **Cache Branching**: Uses `cache_branch` parameter for cache-aware execution + +- **True CFG**: Implements advanced CFG with norm preservation + +To learn more about the diffusion pipeline and how to add a new diffusion pipeline, please view [Adding Diffusion Model](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/model/adding_diffusion_model) + +--- + +## 5. Acceleration Components + +### 5.1 Attention Backends + +**Location**: `vllm_omni/diffusion/attention/` + +#### Architecture + +The attention system uses a **backend selector pattern** that automatically chooses the optimal attention implementation based on hardware and model configuration. + +#### Backend Selection + +**Location**: `vllm_omni/diffusion/attention/selector.py` + +```python +class Attention(nn.Module): + def __init__(self, num_heads, head_size, causal, softmax_scale, ...): + # Auto-select backend + self.attn_backend = get_attn_backend(-1) + self.attn_impl_cls = self.attn_backend.get_impl_cls() + self.attention = self.attn_impl_cls(...) +``` + +**Available Backends**: + +- **FlashAttention**: Optimized CUDA kernel (FA2/FA3) - memory efficient via tiling + +- **SDPA**: PyTorch's scaled dot-product attention - default, cross-platform + +- **SageAttention**: Sparse attention implementation from SageAttention library + +- **AscendAttention**: NPU-optimized attention for Ascend hardware + +These backends provide the **kernel implementations** for attention computation. 
For attention-level sequence parallelism strategies (Ring Attention, Ulysses), see [Parallel Attention](#52-parallel-attention). + +#### Backend Selection Mechanism + +```python +def get_attn_backend(head_size: int) -> type[AttentionBackend]: + # Check environment variable + backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND") + + if backend_name: + return load_backend(backend_name.upper()) + + # Default to SDPA + return SDPABackend +``` + +**Selection Priority**: + +1. **Environment Variable**: `DIFFUSION_ATTENTION_BACKEND` for manual override + + - Valid values: `FLASH_ATTN`, `TORCH_SDPA`, `SAGE_ATTN`, `ASCEND` + + - Example: `export DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN` + +2. **Automatic Fallback**: Falls back to SDPA if selected backend unavailable + +3. **Hardware Detection**: Can select based on device type (NPU, CUDA, etc.) + +**Backend Availability**: + +- **SDPA**: Always available (PyTorch built-in) + +- **FlashAttention**: Requires `flash-attn` package installed + +- **SageAttention**: Requires `sage-attention` package (from THU-ML GitHub) + +- **AscendAttention**: Only available on Ascend NPU hardware + +#### Attention Backend Registry + +**Location**: `vllm_omni/diffusion/attention/selector.py` + +The attention system uses a **registry pattern** to manage and dynamically load attention backends. This allows for easy extension and runtime selection of backends. + + +**Registry Structure**: + +```python +# Registry mapping backend names to their module paths and class names +_BACKEND_CONFIG = { + "FLASH_ATTN": { + "module": "vllm_omni.diffusion.attention.backends.flash_attn", + "class": "FlashAttentionBackend", + }, + "TORCH_SDPA": { + "module": "vllm_omni.diffusion.attention.backends.sdpa", + "class": "SDPABackend", + }, + "SAGE_ATTN": { + "module": "vllm_omni.diffusion.attention.backends.sage_attn", + "class": "SageAttentionBackend", + }, + "ASCEND": { + "module": "vllm_omni.diffusion.attention.backends.ascend_attn", + "class": "AscendAttentionBackend", + }, +} +``` + +#### Attention Backend Integration + +The `Attention` layer integrates backends through a unified interface. Here's how **FlashAttentionBackend** is integrated as an example: + +```python +# attention/backends/flash_attn.py + +class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 96, 128, 192, 256] # FlashAttention supports these head sizes + + +class FlashAttentionImpl(AttentionImpl): + def __init__(self, num_heads, head_size, softmax_scale, causal, ...): + self.num_heads = num_heads + self.causal = causal + self.softmax_scale = softmax_scale + + def forward(self, query, key, value, attn_metadata=None): + # Call FlashAttention kernel + out = flash_attn_func( + query, key, value, + causal=self.causal, + softmax_scale=self.softmax_scale, + ) + return out +``` + +--- + +### 5.2 Parallel Attention + +**Location**: `vllm_omni/diffusion/attention/parallel/` + +#### Architecture + +Parallel attention strategies implement **Sequence Parallelism (SP) at the attention layer level**. These strategies distribute attention computation across multiple GPUs by splitting the sequence dimension, using different communication patterns. 
They work **on top of** AttentionBackend implementations (FlashAttention, SDPA, etc.), handling the parallelization/communication while the backends handle the actual attention computation. + +**Key Distinction**: Unlike AttentionBackend (which provides kernel implementations), ParallelAttentionStrategy provides communication patterns for multi-GPU attention parallelism. These strategies implement the `ParallelAttentionStrategy` interface and use AttentionBackend implementations internally. + +Both Ring Attention and Ulysses are forms of Sequence Parallelism (SP) that: + +- Split the sequence dimension across GPUs + +- Contribute to `sequence_parallel_size` (via `ring_degree` and `ulysses_degree`) + +- Work at the attention layer level (not model/pipeline level) + +#### Ulysses Sequence Parallelism (USP) + +**Location**: `vllm_omni/diffusion/attention/parallel/ulysses.py` + +USP is a sequence-parallel attention strategy that splits attention computation across multiple GPUs by distributing both the sequence dimension and attention heads. It uses **all-to-all communication** to efficiently parallelize attention for very long sequences. Specifically, it uses **all-to-all** collective operations to redistribute Q/K/V tensors before attention computation and gather results afterward. + +Ulysses splits attention computation in two dimensions: + +1. **Sequence Dimension**: Splits the sequence length across GPUs + +2. **Head Dimension**: Splits attention heads across GPUs + +**Configuration**: `ulysses_degree` contributes to `sequence_parallel_size` + +#### Ring Sequence Parallelism + +**Location**: `vllm_omni/diffusion/attention/parallel/ring.py` + +Ring Attention is a **parallel attention strategy** that implements sequence parallelism using ring-based point-to-point (P2P) communication. Unlike attention backends that provide the attention kernel implementation, Ring Attention is a **communication pattern** that works on top of attention backends (FlashAttention or SDPA). + +Ring Attention splits sequence dimension across GPUs in a ring topology, implemented via the `ParallelAttentionStrategy` interface, instead of `AttentionBackend`. P2P ring communication is applied to circulate Key/Value blocks across GPUs. Internally, `ring_flash_attn_func` or `ring_pytorch_attn_func` is used depending on available backends. + +**Architecture**: +```python +class RingParallelAttention: + """Ring sequence-parallel strategy.""" + + def run_attention(self, query, key, value, attn_metadata, ...): + # Selects underlying attention kernel (FlashAttention or SDPA) + if backend_pref == "sdpa": + return ring_pytorch_attn_func(...) # Uses SDPA kernel + else: + return ring_flash_attn_func(...) # Uses FlashAttention kernel +``` + +**Integration**: + +- Ring Attention is activated when `ring_degree > 1` in parallel config + +- It's selected by `build_parallel_attention_strategy()` in the attention layer + +- The `Attention` layer routes to `_run_ring_attention()` when Ring is enabled + +- Works alongside attention backends: Ring handles communication, backends handle computation + +**Configuration**: `ring_degree` contributes to `sequence_parallel_size` + +#### Relationship with AttentionBackend + +Parallel attention strategies (Ring, Ulysses) work **on top of** AttentionBackend implementations: + +- They use AttentionBackend for the actual attention computation (FlashAttention, SDPA, etc.) 
+ +- They handle the multi-GPU communication/parallelization layer + +- They implement `ParallelAttentionStrategy` interface (not `AttentionBackend`) + +For general parallelism strategies (Data Parallelism, Tensor Parallelism, Pipeline Parallelism), see [Parallel Strategies](#54-parallel-strategies). + +--- + +### 5.3 Cache Backends + +**Location**: `vllm_omni/diffusion/cache/` + +#### Architecture + +Cache backends provide a **unified interface** for applying different caching strategies to accelerate diffusion inference. The system supports multiple backends (TeaCache, cache-dit) with a consistent API for enabling and refreshing cache state. + +#### Cache Backend Interface + +```python +class CacheBackend(ABC): + def __init__(self, config: DiffusionCacheConfig): + self.config = config + self.enabled = False + + @abstractmethod + def enable(self, pipeline: Any) -> None: + """Enable cache on the pipeline.""" + raise NotImplementedError + + @abstractmethod + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache state for new generation.""" + raise NotImplementedError + + def is_enabled(self) -> bool: + """Check if cache is enabled.""" + return self.enabled +``` + +**Design Pattern**: + +- **Abstract Base Class**: Defines contract for all cache backends + +- **Pipeline-based**: Works with pipeline instances (not just transformers) + +- **State Management**: Provides refresh mechanism for clean state between generations + +#### Available Backends + +**1. TeaCache Backend** + +**Location**: `vllm_omni/diffusion/cache/teacache/backend.py` + +```python +class TeaCacheBackend(CacheBackend): + def enable(self, pipeline: Any): + # Extract transformer from pipeline + transformer = pipeline.transformer + transformer_type = transformer.__class__.__name__ + + # Create TeaCacheConfig from DiffusionCacheConfig + teacache_config = TeaCacheConfig( + transformer_type=transformer_type, + rel_l1_thresh=self.config.rel_l1_thresh, + coefficients=self.config.coefficients, + ) + + # Apply hooks to transformer + apply_teacache_hook(transformer, teacache_config) + self.enabled = True + + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True): + transformer = pipeline.transformer + if hasattr(transformer, "_hook_registry"): + transformer._hook_registry.reset_hook(TeaCacheHook._HOOK_NAME) +``` + +**TeaCache Features**: + +- **Timestep-aware**: Caches based on timestep embedding similarity + +- **Adaptive**: Dynamically decides when to reuse cached computations + +- **CFG-aware**: Handles positive/negative branches separately + +- **Custom Hook System**: Uses a custom forward interception mechanism (via `HookRegistry`) that wraps the module's `forward` method, allowing transparent integration without modifying model code + +**2. Cache-DiT Backend** + +**Location**: `vllm_omni/diffusion/cache/cache_dit_backend.py` + +```python +class CacheDiTBackend(CacheBackend): + def enable(self, pipeline: Any): + # Uses cache-dit library for acceleration + # Supports DBCache, SCM (Step Computation Masking), TaylorSeer + # Works with single and dual-transformer architectures + ... + self.enabled = True + + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True): + # Updates cache context with new num_inference_steps + ... 
+``` + +**Cache-DiT Features**: + +- **DBCache**: Dynamic block caching with configurable compute blocks + +- **SCM**: Step Computation Masking for additional speedup + +- **TaylorSeer**: Advanced calibration for cache accuracy + +- **Dual-transformer Support**: Handles models like Wan2.2 with two transformers + +#### Cache Backend Selector + +**Location**: `vllm_omni/diffusion/cache/selector.py` + +```python +def get_cache_backend( + cache_backend: str | None, + cache_config: dict | DiffusionCacheConfig +) -> CacheBackend | None: + """Get cache backend instance based on cache_backend string. + + Args: + cache_backend: Cache backend name ("cache_dit", "tea_cache", or None) + cache_config: Cache configuration (dict or DiffusionCacheConfig) + + Returns: + Cache backend instance or None if cache_backend is "none" + """ + if cache_backend is None or cache_backend == "none": + return None + + if isinstance(cache_config, dict): + cache_config = DiffusionCacheConfig.from_dict(cache_config) + + if cache_backend == "cache_dit": + return CacheDiTBackend(cache_config) + elif cache_backend == "tea_cache": + return TeaCacheBackend(cache_config) + else: + raise ValueError(f"Unsupported cache backend: {cache_backend}") +``` + +**Usage Flow**: + +1. **Selection**: `get_cache_backend()` returns appropriate backend instance + +2. **Enable**: `backend.enable(pipeline)` called during worker initialization + +3. **Refresh**: `backend.refresh(pipeline, num_inference_steps)` called before each generation + +4. **Check**: `backend.is_enabled()` verifies cache is active + +### 5.4 Parallel Strategies + +**Location**: `vllm_omni/diffusion/distributed/parallel_state.py` + +#### Parallelism Types + +The system supports multiple orthogonal parallelism strategies: + +**Sequence Parallelism (SP)** + +- **Purpose**: Split sequence dimension across GPUs + +- **Attention-level SP**: Ring Attention and Ulysses (USP) implement SP at the attention layer level + + - See [Parallel Attention](#52-parallel-attention) for details + + - Configuration: `ulysses_degree` × `ring_degree` = `sequence_parallel_size` + +- **Use Case**: Very long sequences (e.g., high-resolution images) + +**Data Parallelism (DP)** + +- **Purpose**: Replicate model across GPUs, split batch + +- **Use Case**: Batch processing, throughput optimization + +**Tensor Parallelism (TP)** (Experimental) + +- **Purpose**: Split model weights across GPUs + +- **Implementation**: Uses vLLM's tensor parallel groups + +- **Use Case**: Large models that don't fit on single GPU + +**CFG Parallelism** (under development) + +- **Purpose**: Parallelize Classifier-Free Guidance (positive/negative prompts) + +- **Infrastructure**: CFG parallel groups are initialized and available via `get_cfg_group()` + +#### Parallel Group Management + +```python +def initialize_model_parallel( + data_parallel_size: int = 1, + cfg_parallel_size: int = 1, + sequence_parallel_size: int | None = None, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + vae_parallel_size: int = 0, +): + # Generate orthogonal parallel groups + rank_generator = RankGenerator( + tensor_parallel_size, + sequence_parallel_size, + pipeline_parallel_size, + cfg_parallel_size, + data_parallel_size, + "tp-sp-pp-cfg-dp", + ) + + # Initialize each parallel group + _DP = init_model_parallel_group(rank_generator.get_ranks("dp"), ...) + _CFG = init_model_parallel_group(rank_generator.get_ranks("cfg"), ...) 
+ _SP = init_model_parallel_group(rank_generator.get_ranks("sp"), ...) + _PP = init_model_parallel_group(rank_generator.get_ranks("pp"), ...) + _TP = init_model_parallel_group(rank_generator.get_ranks("tp"), ...) +``` + +**Rank Order**: `tp-sp-pp-cfg-dp` (tensor → sequence → pipeline → cfg → data) + +**Note**: For attention-level Sequence Parallelism implementations (Ring Attention and Ulysses), see [Parallel Attention](#52-parallel-attention). This section covers higher-level parallelism strategies. + + +--- + +## 6. Data Flow + +### Complete Request Flow + +

+*Figure: End-to-end Data Flow in the vLLM-Omni Diffusion Module*

+ + +``` +1. User Request + └─> OmniDiffusion.generate(prompt) + └─> Prepare OmniDiffusionRequest + └─> DiffusionEngine.step(requests) + +2. Pre-processing + └─> pre_process_func(requests) + └─> Model-specific transformations + +3. Scheduling + └─> scheduler.add_req(requests) + └─> Broadcast via MessageQueue to all workers + +4. Worker Execution + └─> WorkerProc.worker_busy_loop() + └─> GPUWorker.execute_model(reqs) + └─> Pipeline.forward(req) + ├─> encode_prompt() + ├─> prepare_latents() + ├─> diffuse() [loop] + │ ├─> transformer.forward() [with cache backend hooks] + │ └─> scheduler.step() + └─> vae.decode() + +5. Result Collection + └─> Rank 0 sends DiffusionOutput via result queue + └─> Scheduler receives and returns + +6. Post-processing + └─> post_process_func(output) + └─> Convert to PIL images / final format +``` + +--- diff --git a/docs/design/module/entrypoint_module.md b/docs/design/module/entrypoint_module.md new file mode 100644 index 0000000000000000000000000000000000000000..7a26fbb7f05b56a39264281e5dc73903e9edc47c --- /dev/null +++ b/docs/design/module/entrypoint_module.md @@ -0,0 +1 @@ +Architecture design of the entrypoint (update soon) diff --git a/docs/examples/README.md b/docs/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4594d2bed2f7d24f3ddcdfbaab3c7d7615f4f964 --- /dev/null +++ b/docs/examples/README.md @@ -0,0 +1,6 @@ +# Examples + +vLLM-Omni's examples are split into two categories: + +- If you are using vLLM-Omni from within Python code, see the *Offline Inference* section. +- If you are using vLLM-Omni from an HTTP application or client, see the *Online Serving* section. diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md new file mode 100644 index 0000000000000000000000000000000000000000..41aa48c173594f9019b2d38ba50e5b127f80126b --- /dev/null +++ b/docs/features/sleep_mode.md @@ -0,0 +1,39 @@ +# Sleep Mode + +vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches (for autoregressive models)—**without stopping the server or unloading the Docker container**. + +This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html), which provides zero-reload model switching for multi-model serving. + +It is especially useful in **RLHF**, **training**, or **cost-saving scenarios**, where GPU resources must be freed between inference workloads. + +--- + +## Omni Model + +Omni model inherit the feature from vLLM' Sleep Mode + +This means: + +- Support both Level 1 and Level 2 sleep, allow to release and reset both model weights and KV Cache + +## Diffusion Model Extension + +We added Sleep Mode support for **diffusion models**, which previously lacked this functionality. +In diffusion pipelines, this currently only offloads **model weight memory**, as these models typically do not use KV caches. + +This means: + +- Diffusion models can now enter Level 1 sleep. +- Pipeline states (e.g., noise schedulers, buffers) remain intact after waking. +- Useful for releasing VRAM between image generation or training cycles. 
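+
+As an illustration of the resulting sleep/wake cycle, here is a hypothetical
+usage sketch. It assumes the `Omni` entrypoint mirrors vLLM's `sleep()` /
+`wake_up()` methods and uses the `enable_sleep_mode` engine argument shown in
+the next section; check the current API before relying on these names.
+
+```python
+from vllm_omni.entrypoints.omni import Omni
+
+# Engine created with sleep mode enabled (see "Enable sleep mode" below).
+omni = Omni(model="Tongyi-MAI/Z-Image-Turbo", enable_sleep_mode=True)
+
+outputs = omni.generate("a cup of coffee on the table")
+
+# Level 1 sleep: offload model weights and free most GPU memory.
+omni.sleep(level=1)
+
+# ... run training or another GPU workload here ...
+
+# Restore weights and keep generating without restarting the server or process.
+omni.wake_up()
+outputs = omni.generate("a golden retriever playing in the snow")
+```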
+ +--- + +## Enable sleep mode +To enable sleep mode, set the `enable_sleep_mode` in `engine_args` to `True` + + +Example: +```python +omni = Omni(model=...,enable_sleep_mode=True) +``` diff --git a/docs/getting_started/installation/.nav.yml b/docs/getting_started/installation/.nav.yml new file mode 100644 index 0000000000000000000000000000000000000000..0fcee9a008ca95f6fa4134f6590d7e856be4521f --- /dev/null +++ b/docs/getting_started/installation/.nav.yml @@ -0,0 +1,4 @@ +nav: + - README.md + - gpu.md + - npu.md diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..20a737dca2d7d927d1dfc8948ccab2e010fcf8a3 --- /dev/null +++ b/docs/getting_started/installation/README.md @@ -0,0 +1,8 @@ +# Installation + +vLLM-Omni supports the following hardware platforms: + +- [GPU](gpu.md) + - [NVIDIA CUDA](gpu.md) + - [AMD ROCm](gpu.md) +- [NPU](npu.md) diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md new file mode 100644 index 0000000000000000000000000000000000000000..73c974280aefe8136fe457d862208b92160f5439 --- /dev/null +++ b/docs/getting_started/installation/gpu.md @@ -0,0 +1,66 @@ +# GPU + +vLLM-Omni is a Python library that supports the following GPU variants. The library itself mainly contains python implementations for framework and models. + +## Requirements + +- OS: Linux +- Python: 3.12 + +!!! note + vLLM-Omni is currently not natively supported on Windows. + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:requirements" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:requirements" + +## Set up using Python + +### Create a new Python environment + +--8<-- "docs/getting_started/installation/python_env_setup.inc.md" + +### Pre-built wheels + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-wheels" + + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-wheels" + +[](){ #build-from-source } + +### Build wheel from source + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:build-wheel-from-source" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-wheel-from-source" + +## Set up using Docker + +### Pre-built images + +=== "NVIDIA CUDA" + + --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-images" + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-images" + +### Build your own docker image + +=== "AMD ROCm" + + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-docker" diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md new file mode 100644 index 0000000000000000000000000000000000000000..3996a02437645009ed42829019bbda78549d1fad --- /dev/null +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -0,0 +1,103 @@ +# --8<-- [start:requirements] + +- GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) + +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] + +vLLM-Omni depends vLLM. So please follow instructions below mainly for vLLM. + +!!! note + PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See for more details. + +In order to be performant, vLLM has to compile many cuda kernels. 
The compilation unfortunately introduces binary incompatibility with other CUDA versions and PyTorch versions, even for the same PyTorch version with different building configurations. + +Therefore, it is recommended to install vLLM and vLLM-Omni with a **fresh new** environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See [build-from-source-vllm](https://docs.vllm.ai/en/stable/getting_started/installation/gpu/#build-wheel-from-source) for more details. + +# --8<-- [start:pre-built-wheels] + +#### Installation of vLLM +Note: Pre-built wheels are currently only available for vLLM-Omni 0.11.0rc1, 0.12.0rc1, 0.14.0rc1, 0.14.0. For the latest version, please [build from source](https://docs.vllm.ai/projects/vllm-omni/en/latest/getting_started/installation/gpu/#build-wheel-from-source). + + +vLLM-Omni is built based on vLLM. Please install it with command below. +```bash +uv pip install vllm==0.14.0 --torch-backend=auto +``` + +#### Installation of vLLM-Omni + +```bash +uv pip install vllm-omni +``` + +# --8<-- [end:pre-built-wheels] + +# --8<-- [start:build-wheel-from-source] + +#### Installation of vLLM +If you do not need to modify source code of vLLM, you can directly install the stable 0.14.0 release version of the library + +```bash +uv pip install vllm==0.14.0 --torch-backend=auto +``` + +The release 0.14.0 of vLLM is based on PyTorch 2.9.0 which requires CUDA 12.9 environment. + +#### Installation of vLLM-Omni +Since vllm-omni is rapidly evolving, it's recommended to install it from source +```bash +git clone https://github.com/vllm-project/vllm-omni.git +cd vllm-omni +uv pip install -e . +``` + +
(Optional) Installation of vLLM from source +If you want to check, modify or debug with source code of vLLM, install the library from source with the following instructions: + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout v0.14.0 +``` +Set up environment variables to get pre-built wheels. If there are internet problems, just download the whl file manually. And set `VLLM_PRECOMPILED_WHEEL_LOCATION` as your local absolute path of whl file. +```bash +export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.14.0/vllm-0.14.0-cp38-abi3-manylinux_2_31_x86_64.whl +``` +Install vllm with command below (If you have no existing PyTorch). +```bash +uv pip install --editable . +``` +Install vllm with command below (If you already have PyTorch). +```bash +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation --editable . +``` +
+ +# --8<-- [end:build-wheel-from-source] + +# --8<-- [start:build-wheel-from-source-in-docker] + +# --8<-- [end:build-wheel-from-source-in-docker] + +# --8<-- [start:pre-built-images] + +vLLM-Omni offers an official docker image for deployment. These images are built on top of vLLM docker images and available on Docker Hub as [vllm/vllm-omni](https://hub.docker.com/r/vllm/vllm-omni/tags). The version of vLLM-Omni indicates which release of vLLM it is based on. + +Here's an example deployment command that has been verified on 2 x H100's: +```bash +docker run --runtime nvidia --gpus 2 \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HF_TOKEN=$HF_TOKEN" \ + -p 8091:8091 \ + --ipc=host \ + vllm/vllm-omni:v0.14.0 \ + --model Qwen/Qwen3-Omni-30B-A3B-Instruct --port 8091 +``` + +!!! tip + You can use this docker image to serve models the same way you would with in vLLM! To do so, make sure you overwrite the default entrypoint (`vllm serve --omni`) which works only for models supported in the vLLM-Omni project. + +# --8<-- [end:pre-built-images] diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md new file mode 100644 index 0000000000000000000000000000000000000000..1a8ffb612849bdd3969d7c32ea4028040790146f --- /dev/null +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -0,0 +1,105 @@ +# --8<-- [start:requirements] + +- GPU: Validated on gfx942 (It should be supported on the AMD GPUs that are supported by vLLM.) + +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] + +vLLM-Omni current recommends the steps in under setup through Docker Images. + +# --8<-- [start:pre-built-wheels] + +# --8<-- [end:pre-built-wheels] + +# --8<-- [start:build-wheel-from-source] + +# --8<-- [end:build-wheel-from-source] + +# --8<-- [start:build-docker] + +#### Build docker image + +```bash +DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-omni-rocm . +``` + +#### Launch the docker image + +##### Launch with OpenAI API Server + +``` +docker run --rm \ +--group-add=video \ +--ipc=host \ +--cap-add=SYS_PTRACE \ +--security-opt seccomp=unconfined \ +--device /dev/kfd \ +--device /dev/dri \ +-v ~/.cache/huggingface:/root/.cache/huggingface \ +--env "HF_TOKEN=$HF_TOKEN" \ +-p 8091:8091 \ +--ipc=host \ +vllm-omni-rocm \ +--model Qwen/Qwen3-Omni-30B-A3B-Instruct --port 8091 +``` + +##### Launch with interactive session for development + +``` +docker run --rm -it \ +--network=host \ +--group-add=video \ +--ipc=host \ +--cap-add=SYS_PTRACE \ +--security-opt seccomp=unconfined \ +--device /dev/kfd \ +--device /dev/dri \ +-v :/app/model \ +-v ~/.cache/huggingface:/root/.cache/huggingface \ +--entrypoint bash \ +vllm-omni-rocm +``` + +# --8<-- [end:build-docker] + +# --8<-- [start:pre-built-images] + +vLLM-Omni offers an official docker image for deployment. These images are built on top of vLLM docker images and available on Docker Hub as [vllm/vllm-omni-rocm](https://hub.docker.com/r/vllm/vllm-omni-rocm/tags). The version of vLLM-Omni indicates which release of vLLM it is based on. 
+
+#### Launch vLLM-Omni Server
+Here's an example deployment command that has been verified on 2 x MI300 GPUs:
+```bash
+docker run --rm \
+    --group-add=video \
+    --ipc=host \
+    --cap-add=SYS_PTRACE \
+    --security-opt seccomp=unconfined \
+    --device /dev/kfd \
+    --device /dev/dri \
+    -v :/app/model \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    --env "HF_TOKEN=$HF_TOKEN" \
+    -p 8091:8091 \
+    vllm/vllm-omni-rocm:v0.14.0 \
+    --model Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
+```
+
+#### Launch an interactive terminal with the prebuilt docker image
+If you want to work in a development environment, you can launch the docker image as follows:
+```bash
+docker run --rm -it \
+    --network=host \
+    --group-add=video \
+    --ipc=host \
+    --cap-add=SYS_PTRACE \
+    --security-opt seccomp=unconfined \
+    --device /dev/kfd \
+    --device /dev/dri \
+    -v :/app/model \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    --env "HF_TOKEN=$HF_TOKEN" \
+    --entrypoint bash \
+    vllm/vllm-omni-rocm:v0.14.0
+```
+
+# --8<-- [end:pre-built-images]
diff --git a/docs/getting_started/installation/npu.md b/docs/getting_started/installation/npu.md
new file mode 100644
index 0000000000000000000000000000000000000000..197bcec305ba877bfa5ea6c79b670883fc7fd41e
--- /dev/null
+++ b/docs/getting_started/installation/npu.md
@@ -0,0 +1,23 @@
+# NPU
+
+vLLM-Omni supports NPU through the vLLM Ascend Plugin (vllm-ascend). This is a community-maintained hardware plugin for running vLLM on NPU.
+
+## Requirements
+
+- OS: Linux
+- Python: 3.12
+
+!!! note
+    vLLM-Omni is currently not natively supported on Windows.
+
+=== "NPU"
+
+    --8<-- "docs/getting_started/installation/npu/npu.inc.md:requirements"
+
+## Installation
+
+### Recommended
+
+=== "NPU"
+
+    --8<-- "docs/getting_started/installation/npu/npu.inc.md:installation"
diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md
new file mode 100644
index 0000000000000000000000000000000000000000..9044bb0898b229bdbaa943ba06d3d7a264d4f01b
--- /dev/null
+++ b/docs/getting_started/installation/npu/npu.inc.md
@@ -0,0 +1,59 @@
+# --8<-- [start:requirements]
+
+For detailed hardware and software requirements, please refer to the [vllm-ascend installation documentation](https://docs.vllm.ai/projects/ascend/en/latest/installation.html).
+
+# --8<-- [end:requirements]
+# --8<-- [start:installation]
+
+The recommended way to use vLLM-Omni on NPU is through the vllm-ascend pre-built Docker images:
+
+```bash
+# Update DEVICE according to your NPUs (/dev/davinci[0-7])
+export DEVICE0=/dev/davinci0
+export DEVICE1=/dev/davinci1
+# Update the vllm-ascend image
+# Atlas A2:
+# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0
+# Atlas A3:
+# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0-a3
+export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0
+docker run --rm \
+    --name vllm-omni-npu \
+    --shm-size=1g \
+    --device $DEVICE0 \
+    --device $DEVICE1 \
+    --device /dev/davinci_manager \
+    --device /dev/devmm_svm \
+    --device /dev/hisi_hdc \
+    -v /usr/local/dcmi:/usr/local/dcmi \
+    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
+    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
+    -v /etc/ascend_install.info:/etc/ascend_install.info \
+    -v /root/.cache:/root/.cache \
+    -p 8000:8000 \
+    -it $IMAGE bash
+
+# Install the mooncake dependency that is missing from the original image.
+apt update
+apt install libjemalloc2
+echo "export LD_PRELOAD=/usr/lib/$(uname -m)-linux-gnu/libjemalloc.so.2:$LD_PRELOAD" >> ~/.bashrc
+source ~/.bashrc
+
+# Inside the container, install vLLM-Omni from source
+cd /vllm-workspace
+git clone -b v0.14.0 https://github.com/vllm-project/vllm-omni.git
+cd vllm-omni
+pip install -v -e .
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+
+# (Optional) Disable mooncake for better stability
+mv /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake \
+    /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake.disabled
+```
+
+The default working directory is `/workspace`, while the vLLM, vLLM-Ascend, and vLLM-Omni code is placed in `/vllm-workspace` and installed in development mode.
+
+For other installation methods (pip installation, building from source, custom Docker builds), please refer to the [vllm-ascend installation guide](https://docs.vllm.ai/projects/ascend/en/latest/installation.html).
+
+# --8<-- [end:installation]
diff --git a/docs/getting_started/installation/python_env_setup.inc.md b/docs/getting_started/installation/python_env_setup.inc.md
new file mode 100644
index 0000000000000000000000000000000000000000..06794f8d3120e5cf8d5f43d73f88f175454967dc
--- /dev/null
+++ b/docs/getting_started/installation/python_env_setup.inc.md
@@ -0,0 +1,6 @@
+It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:
+
+```bash
+uv venv --python 3.12 --seed
+source .venv/bin/activate
+```
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md
new file mode 100644
index 0000000000000000000000000000000000000000..d4087621ad53e5235f1aa74dff8e5f52895997ef
--- /dev/null
+++ b/docs/getting_started/quickstart.md
@@ -0,0 +1,116 @@
+# Quickstart
+
+This guide will help you quickly get started with vLLM-Omni to perform:
+
+- Offline batched inference
+- Online serving using an OpenAI-compatible server
+
+## Prerequisites
+
+- OS: Linux
+- Python: 3.12
+
+## Installation
+
+For installation on GPU from source:
+
+```bash
+uv venv --python 3.12 --seed
+source .venv/bin/activate
+
+# On CUDA
+uv pip install vllm==0.14.0 --torch-backend=auto
+
+# On ROCm
+uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
+
+git clone https://github.com/vllm-project/vllm-omni.git
+cd vllm-omni
+uv pip install -e .
+```
+
+For additional installation methods, please see the [installation guide](installation/README.md).
+
+## Offline Inference
+
+Text-to-image generation quickstart with vLLM-Omni:
+
+```python
+from vllm_omni.entrypoints.omni import Omni
+
+if __name__ == "__main__":
+    omni = Omni(model="Tongyi-MAI/Z-Image-Turbo")
+    prompt = "a cup of coffee on the table"
+    outputs = omni.generate(prompt)
+    images = outputs[0].request_output[0].images
+    images[0].save("coffee.png")
+```
+
+You can also pass a list of prompts and have them processed together, as shown below.
+
+!!! info
+
+    However, this is not currently recommended:
+    not all models support batch inference,
+    and batched requests usually do not bring a significant performance improvement (despite the impression that they might).
+    This feature exists primarily for interface compatibility with vLLM and to allow for future improvements.
+ +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + # stage_configs_path="./stage-config.yaml", # See below + ) + prompts = [ + "a cup of coffee on a table", + "a toy dinosaur on a sandy beach", + "a fox waking up in bed and yawning", + ] + omni_outputs = omni.generate(prompts) + for i_prompt, prompt_output in enumerate(omni_outputs): + this_request_output = prompt_output.request_output[0] + this_images = this_request_output.images + for i_image, image in enumerate(this_images): + image.save(f"p{i_prompt}-img{i_image}.jpg") + print("saved to", f"p{i_prompt}-img{i_image}.jpg") + # saved to p0-img0.jpg + # saved to p1-img0.jpg + # saved to p2-img0.jpg +``` + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + +For more usages, please refer to [offline inference](../user_guide/examples/offline_inference/qwen2_5_omni.md) + +## Online Serving with OpenAI-Completions API + +Text-to-image generation quickstart with vLLM-Omni: + +```bash +vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8091 +``` + +```bash +curl -s http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "a cup of coffee on the table"} + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 4.0, + "seed": 42 + } + }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > coffee.png +``` + +For more details, please refer to [online serving](../user_guide/examples/online_serving/text_to_image.md). diff --git a/docs/mkdocs/hooks/generate_api_readme.py b/docs/mkdocs/hooks/generate_api_readme.py new file mode 100644 index 0000000000000000000000000000000000000000..9c344d3f14d4f8c6af3cd83ed04b1c26d81f5aa9 --- /dev/null +++ b/docs/mkdocs/hooks/generate_api_readme.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project +""" +Hook to automatically generate docs/api/README.md from the codebase. + +This script scans the vllm_omni module for public classes and functions, +categorizes them, and generates a summary README file. 
+""" + +import ast +import logging +from pathlib import Path + +logger = logging.getLogger("mkdocs") + +ROOT_DIR = Path(__file__).parent.parent.parent.parent +API_README_PATH = ROOT_DIR / "docs" / "api" / "README.md" + +# Category mappings: module prefix -> category name and description +CATEGORIES = { + "entrypoints": { + "name": "Entry Points", + "description": "Main entry points for vLLM-Omni inference and serving.", + }, + "inputs": { + "name": "Inputs", + "description": "Input data structures for multi-modal inputs.", + }, + "engine": { + "name": "Engine", + "description": "Engine classes for offline and online inference.", + }, + "core": { + "name": "Core", + "description": "Core scheduling and caching components.", + }, + # "model_executor": { + # "name": "Model Executor", + # "description": "Model execution components.", + # }, + "config": { + "name": "Configuration", + "description": "Configuration classes.", + }, + "worker": { + "name": "Workers", + "description": "Worker classes and model runners for distributed inference.", + }, +} + + +class APIVisitor(ast.NodeVisitor): + """AST visitor to extract public classes and module-level functions.""" + + def __init__(self, module_path: str): + self.module_path = module_path + self.classes: list[str] = [] + self.functions: list[str] = [] + self._class_stack: list[str] = [] # Track nested class definitions + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + """Visit class definitions.""" + if not node.name.startswith("_"): + self.classes.append(f"{self.module_path}.{node.name}") + # Track that we're entering a class + self._class_stack.append(node.name) + self.generic_visit(node) + # Remove from stack when done visiting + self._class_stack.pop() + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """Visit function definitions - only collect module-level functions.""" + # Only collect if we're not inside a class (stack is empty) + if not self._class_stack and not node.name.startswith("_"): + self.functions.append(f"{self.module_path}.{node.name}") + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + """Visit async function definitions - only collect module-level functions.""" + # Only collect if we're not inside a class (stack is empty) + if not self._class_stack and not node.name.startswith("_"): + self.functions.append(f"{self.module_path}.{node.name}") + self.generic_visit(node) + + +def parse_file_for_symbols(file_path: Path, module_path: str) -> tuple[list[str], list[str]]: + """ + Parse a Python file and extract public classes and functions. + + Returns: + Tuple of (classes, functions) + """ + try: + # If this is __init__.py, use parent module path + if file_path.name == "__init__.py": + # Remove .__init__ from module path + if module_path.endswith(".__init__"): + module_path = module_path[:-9] + + with open(file_path, encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(file_path)) + visitor = APIVisitor(module_path) + visitor.visit(tree) + + return visitor.classes, visitor.functions + except Exception as e: + logger.debug(f"Could not parse {file_path}: {e}") + return [], [] + + +def scan_package(package_name: str = "vllm_omni") -> dict[str, list[str]]: + """ + Scan the vllm_omni package and categorize public symbols. 
+ + Returns: + Dict mapping category names to lists of symbol full names + """ + categorized: dict[str, list[str]] = {cat["name"]: [] for cat in CATEGORIES.values()} + + try: + # Find the package directory + package_path = ROOT_DIR / package_name + if not package_path.exists(): + logger.warning(f"Package path not found: {package_path}") + return categorized + + # Walk through all Python files + for py_file in package_path.rglob("*.py"): + # Skip __init__.py and private modules + if py_file.name.startswith("_") and py_file.name != "__init__.py": + continue + + # Get module path + relative_path = py_file.relative_to(ROOT_DIR) + module_path = str(relative_path.with_suffix("")).replace("/", ".").replace("\\", ".") + + # Skip excluded modules (avoid importing vllm during docs build) + excluded_prefixes = [ + "vllm_omni.diffusion.models.qwen_image", + "vllm_omni.entrypoints.async_diffusion", + "vllm_omni.entrypoints.openai", + ] + if any(module_path.startswith(prefix) for prefix in excluded_prefixes): + continue + + # Handle __init__.py - use parent module path + if py_file.name == "__init__.py": + # Remove .__init__ from module path + if module_path.endswith(".__init__"): + module_path = module_path[:-9] + + # Determine category from module path + category = None + for prefix, cat_info in CATEGORIES.items(): + if prefix in module_path: + category = cat_info["name"] + break + + if not category: + continue + + # Parse file for symbols + classes, functions = parse_file_for_symbols(py_file, module_path) + + # Filter out internal implementation classes + # Skip classes that look like internal components (DiT layers, etc.) + internal_patterns = [ + "Block", + "Layer", + "Net", + "Embedding", + "Norm", + "Activation", + "Solver", + "Pooling", + "Attention", + "MLP", + "DecoderLayer", + "InputEmbedding", + "TimestepEmbedding", + "CodecEmbedding", + "DownSample", + "UpSample", + "Res2Net", + "SqueezeExcitation", + "TimeDelay", + "TorchActivation", + "SnakeBeta", + "SinusPosition", + "RungeKutta", + "AMPBlock", + "AdaLayerNorm", + ] + + # Add classes (filter out internal ones) + for class_name in classes: + class_short_name = class_name.split(".")[-1] + # Skip if it matches internal patterns (unless it's a main model class) + if any(pattern in class_short_name for pattern in internal_patterns): + # But include main model classes + if not any( + main_class in class_short_name + for main_class in [ + "ForConditionalGeneration", + "Model", + "Registry", + "Worker", + "Runner", + "Scheduler", + "Manager", + "Processor", + "Config", + ] + ): + continue + categorized[category].append(class_name) + + # Add important functions (parse, preprocess, etc.) 
+ for func_name in functions: + # Include functions that match certain patterns + if any(keyword in func_name.lower() for keyword in ["parse", "preprocess"]): + categorized[category].append(func_name) + + # Sort symbols within each category + for category in categorized: + categorized[category].sort() + + except Exception as e: + logger.error(f"Error scanning package: {e}", exc_info=True) + + return categorized + + +def generate_readme(categorized: dict[str, list[str]]) -> str: + """Generate the API README markdown content.""" + lines = ["# Summary", ""] + + # Generate sections for each category + for prefix, cat_info in CATEGORIES.items(): + category_name = cat_info["name"] + description = cat_info["description"] + symbols = categorized.get(category_name, []) + + if not symbols: + continue + + lines.append(f"## {category_name}") + lines.append("") + lines.append(description) + lines.append("") + + for symbol in symbols: + lines.append(f"- [{symbol}][]") + + lines.append("") + + return "\n".join(lines) + + +def on_startup(command, dirty: bool): + """MkDocs hook entry point.""" + logger.info("Generating API README documentation") + + # Scan the package + categorized = scan_package() + + # Generate README content + content = generate_readme(categorized) + + # Write to file + API_README_PATH.parent.mkdir(parents=True, exist_ok=True) + with open(API_README_PATH, "w", encoding="utf-8") as f: + f.write(content) + + logger.info(f"API README generated: {API_README_PATH.relative_to(ROOT_DIR)}") diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..4e840280b26537605dc22fd7c6479521f3c632d5 --- /dev/null +++ b/docs/mkdocs/hooks/generate_examples.py @@ -0,0 +1,336 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal + +import regex as re +import yaml + +logger = logging.getLogger("mkdocs") + +ROOT_DIR = Path(__file__).parent.parent.parent.parent +ROOT_DIR_RELATIVE = "../../../../.." +EXAMPLE_DIR = ROOT_DIR / "examples" +EXAMPLE_DOC_DIR = ROOT_DIR / "docs/user_guide/examples" +NAV_FILE = ROOT_DIR / "docs/.nav.yml" + + +def fix_case(text: str) -> str: + subs = { + "api": "API", + "cli": "CLI", + "cpu": "CPU", + "llm": "LLM", + "mae": "MAE", + "tpu": "TPU", + "gguf": "GGUF", + "lora": "LoRA", + "rlhf": "RLHF", + "vllm": "vLLM", + "openai": "OpenAI", + "lmcache": "LMCache", + "multilora": "MultiLoRA", + "mlpspeculator": "MLPSpeculator", + r"fp\d+": lambda x: x.group(0).upper(), # e.g. fp16, fp32 + r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 + } + for pattern, repl in subs.items(): + text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE) + return text + + +@dataclass +class Example: + """ + Example class for generating documentation content from a given path. + + Attributes: + path (Path): The path to the main directory or file. + category (str): The category of the document. + main_file (Path): The main file in the directory. + other_files (list[Path]): list of other files in the directory. + title (str): The title of the document. + + Methods: + __post_init__(): Initializes the main_file, other_files, and title attributes. + determine_main_file() -> Path: Determines the main file in the given path. 
+ determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. + determine_title() -> str: Determines the title of the document. + generate() -> str: Generates the documentation content. + """ # noqa: E501 + + path: Path + category: str = None + main_file: Path = field(init=False) + other_files: list[Path] = field(init=False) + title: str = field(init=False) + + def __post_init__(self): + self.main_file = self.determine_main_file() + self.other_files = self.determine_other_files() + self.title = self.determine_title() + + @property + def is_code(self) -> bool: + return self.main_file.suffix != ".md" + + def determine_main_file(self) -> Path: + """ + Determines the main file in the given path. + If the path is a file, it returns the path itself. Otherwise, it searches + for Markdown files (*.md) in the directory and returns the first one found. + Returns: + Path: The main file path, either the original path if it's a file or the first + Markdown file found in the directory. + Raises: + IndexError: If no Markdown files are found in the directory. + """ # noqa: E501 + return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop() + + def determine_other_files(self) -> list[Path]: + """ + Determine other files in the directory excluding the main file. + + This method checks if the given path is a file. If it is, it returns an empty list. + Otherwise, it recursively searches through the directory and returns a list of all + files that are not the main file. + + Returns: + list[Path]: A list of Path objects representing the other files in the directory. + """ # noqa: E501 + if self.path.is_file(): + return [] + # Binary file extensions to exclude + binary_extensions = { + ".wav", + ".mp3", + ".mp4", + ".avi", + ".mov", + ".mkv", # Audio/Video + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".ico", + ".svg", # Images + ".pdf", + ".zip", + ".tar", + ".gz", + ".bz2", + ".xz", # Archives/Documents + ".exe", + ".so", + ".dll", + ".dylib", # Binaries + ".bin", + ".dat", + ".db", + ".sqlite", # Data files + } + + def is_other_file(file: Path) -> bool: + return file.is_file() and file != self.main_file and file.suffix.lower() not in binary_extensions + + return [file for file in self.path.rglob("*") if is_other_file(file)] + + def determine_title(self) -> str: + if not self.is_code: + # Specify encoding for building on Windows + with open(self.main_file, encoding="utf-8") as f: + first_line = f.readline().strip() + match = re.match(r"^#\s+(?P.+)$", first_line) + if match: + return match.group("title") + return fix_case(self.path.stem.replace("_", " ").title()) + + def fix_relative_links(self, content: str) -> str: + """ + Fix relative links in markdown content by converting them to gh-file + format. 
+ + Args: + content (str): The markdown content to process + + Returns: + str: Content with relative links converted to gh-file format + """ + # Regex to match markdown links [text](relative_path) + # This matches links that don't start with http, https, ftp, or # + link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)" + + def replace_link(match): + link_text = match.group(1) + relative_path = match.group(2) + + # Make relative to repo root + gh_file = (self.main_file.parent / relative_path).resolve() + gh_file = gh_file.relative_to(ROOT_DIR) + + # Make GitHub URL + url = "https://github.com/vllm-project/vllm-omni/" + url += "tree/main" if self.path.is_dir() else "blob/main" + gh_url = f"{url}/{gh_file}" + + return f"[{link_text}]({gh_url})" + + return re.sub(link_pattern, replace_link, content) + + def generate(self) -> str: + content = f"# {self.title}\n\n" + url = "https://github.com/vllm-project/vllm-omni/" + url += "tree/main" if self.path.is_dir() else "blob/main" + content += f"Source <{url}/{self.path.relative_to(ROOT_DIR)}>.\n\n" + + # Use long code fence to avoid issues with + # included files containing code fences too + code_fence = "``````" + + if self.is_code: + main_file_rel = self.main_file.relative_to(ROOT_DIR) + content += f'{code_fence}{self.main_file.suffix[1:]}\n--8<-- "{main_file_rel}"\n{code_fence}\n' + else: + with open(self.main_file, encoding="utf-8") as f: + # Skip the title from md snippets as it's been included above + main_content = f.readlines()[1:] + content += self.fix_relative_links("".join(main_content)) + content += "\n" + + if not self.other_files: + return content + + content += "## Example materials\n\n" + for file in sorted(self.other_files): + content += f'??? abstract "{file.relative_to(self.path)}"\n' + if file.suffix != ".md": + content += f" {code_fence}{file.suffix[1:]}\n" + content += f' --8<-- "{file.relative_to(ROOT_DIR)}"\n' + if file.suffix != ".md": + content += f" {code_fence}\n" + + return content + + +def update_nav_file(examples: list[Example]): + """ + Update the .nav.yml file to include all generated examples. + This function completely regenerates the examples section based on the actual + folder structure, ensuring consistency between the examples folder and nav file. 
+ + Args: + examples: List of Example objects that have been generated + """ + if not NAV_FILE.exists(): + logger.warning("Navigation file not found: %s", NAV_FILE) + return + + # Read the current nav file + with open(NAV_FILE, encoding="utf-8") as f: + nav_data = yaml.safe_load(f) or {} + + nav_list = nav_data.get("nav", []) + + # Find the "User Guide" section + user_guide_idx = None + examples_idx = None + for i, item in enumerate(nav_list): + if isinstance(item, dict) and "User Guide" in item: + user_guide_idx = i + user_guide_content = item["User Guide"] + # Find the "Examples" subsection + for j, subitem in enumerate(user_guide_content): + if isinstance(subitem, dict) and "Examples" in subitem: + examples_idx = j + break + break + + if user_guide_idx is None or examples_idx is None: + logger.warning("Could not find 'User Guide' -> 'Examples' section in nav file") + return + + # Get existing Examples section to preserve non-example items (like README.md) + existing_examples_content = nav_list[user_guide_idx]["User Guide"][examples_idx]["Examples"] + + # Preserve string items (like "examples/README.md") that are not example categories + preserved_items = [ + item + for item in existing_examples_content + if isinstance(item, str) and not item.startswith("user_guide/examples/") + ] + + # Group examples by category + examples_by_category = {} + for example in examples: + category = example.category + if category not in examples_by_category: + examples_by_category[category] = [] + examples_by_category[category].append(example) + + # Build the new Examples section - start with preserved items + examples_section = preserved_items.copy() + + # Add examples grouped by category, sorted by category name + for category in sorted(examples_by_category.keys()): + category_examples = sorted(examples_by_category[category], key=lambda e: e.path.stem) + category_items = [] + for example in category_examples: + doc_path = EXAMPLE_DOC_DIR / example.category / f"{example.path.stem}.md" + rel_path = doc_path.relative_to(ROOT_DIR / "docs") + category_items.append({example.title: str(rel_path)}) + + if category_items: + # Format category name (e.g., "offline_inference" -> "Offline Inference") + category_title = fix_case(category.replace("_", " ").title()) + examples_section.append({category_title: category_items}) + + # Update the nav structure + nav_list[user_guide_idx]["User Guide"][examples_idx]["Examples"] = examples_section + + # Write back to file + nav_data["nav"] = nav_list + with open(NAV_FILE, "w", encoding="utf-8") as f: + yaml.dump(nav_data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + logger.info("Updated navigation file: %s", NAV_FILE.relative_to(ROOT_DIR)) + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + logger.info("Generating example documentation") + logger.debug("Root directory: %s", ROOT_DIR.resolve()) + logger.debug("Example directory: %s", EXAMPLE_DIR.resolve()) + logger.debug("Example document directory: %s", EXAMPLE_DOC_DIR.resolve()) + + # Create the EXAMPLE_DOC_DIR if it doesn't exist + if not EXAMPLE_DOC_DIR.exists(): + EXAMPLE_DOC_DIR.mkdir(parents=True) + + categories = sorted(p for p in EXAMPLE_DIR.iterdir() if p.is_dir()) + + examples = [] + glob_patterns = ["*.py", "*.md", "*.sh"] + # Find categorised examples + for category in categories: + globs = [category.glob(pattern) for pattern in glob_patterns] + for path in itertools.chain(*globs): + examples.append(Example(path, category.stem)) + # Find examples in 
subdirectories + for path in category.glob("*/*.md"): + examples.append(Example(path.parent, category.stem)) + + # Generate the example documentation + for example in sorted(examples, key=lambda e: e.path.stem): + example_name = f"{example.path.stem}.md" + doc_path = EXAMPLE_DOC_DIR / example.category / example_name + if not doc_path.parent.exists(): + doc_path.parent.mkdir(parents=True) + # Specify encoding for building on Windows + with open(doc_path, "w+", encoding="utf-8") as f: + f.write(example.generate()) + logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) + + # Update the navigation file + update_nav_file(examples) diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py new file mode 100644 index 0000000000000000000000000000000000000000..8798b11db203f9857f49d01051d6503772544655 --- /dev/null +++ b/docs/mkdocs/hooks/url_schemes.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This is basically a port of MyST parser’s external URL resolution mechanism +(https://myst-parser.readthedocs.io/en/latest/syntax/cross-referencing.html#customising-external-url-resolution) +to work with MkDocs. + +It allows Markdown authors to use GitHub shorthand links like: + + - [Text](gh-issue:123) + - <gh-pr:456> + - [File](gh-file:path/to/file.py#L10) + +These are automatically rewritten into fully qualified GitHub URLs pointing to +issues, pull requests, files, directories, or projects in the +`vllm-project/vllm-omni` repository. + +The goal is to simplify cross-referencing common GitHub resources +in project docs. +""" + +import regex as re +from mkdocs.config.defaults import MkDocsConfig +from mkdocs.structure.files import Files +from mkdocs.structure.pages import Page + + +def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, files: Files) -> str: + """ + Custom MkDocs plugin hook to rewrite special GitHub reference links + in Markdown. + + This function scans the given Markdown content for specially formatted + GitHub shorthand links, such as: + - `[Link text](gh-issue:123)` + - `<gh-pr:456>` + + And rewrites them into fully-qualified GitHub URLs with GitHub icons: + - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` + - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` + + Supported shorthand types: + - `gh-issue` + - `gh-pr` + - `gh-project` + - `gh-dir` + - `gh-file` + + Args: + markdown (str): The raw Markdown content of the page. + page (Page): The MkDocs page object being processed. + config (MkDocsConfig): The MkDocs site configuration. + files (Files): The collection of files in the MkDocs build. + + Returns: + str: The updated Markdown content with GitHub shorthand links replaced. 
+ """ + gh_icon = ":octicons-mark-github-16:" + gh_url = "https://github.com" + repo_url = f"{gh_url}/vllm-project/vllm-omni" + org_url = f"{gh_url}/orgs/vllm-project" + + # Mapping of shorthand types to their corresponding GitHub base URLs + urls = { + "issue": f"{repo_url}/issues", + "pr": f"{repo_url}/pull", + "project": f"{org_url}/projects", + "dir": f"{repo_url}/tree/main", + "file": f"{repo_url}/blob/main", + } + + # Default title prefixes for auto links + titles = { + "issue": "Issue #", + "pr": "Pull Request #", + "project": "Project #", + "dir": "", + "file": "", + } + + # Regular expression to match GitHub shorthand links + scheme = r"gh-(?P<type>.+?):(?P<path>.+?)(#(?P<fragment>.+?))?" + inline_link = re.compile(r"\[(?P<title>[^\[]+?)\]\(" + scheme + r"\)") + auto_link = re.compile(f"<{scheme}>") + + def replace_inline_link(match: re.Match) -> str: + """ + Replaces a matched inline-style GitHub shorthand link + with a full Markdown link. + + Example: + [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) + """ + url = f"{urls[match.group('type')]}/{match.group('path')}" + if fragment := match.group("fragment"): + url += f"#{fragment}" + + return f"[{gh_icon} {match.group('title')}]({url})" + + def replace_auto_link(match: re.Match) -> str: + """ + Replaces a matched autolink-style GitHub shorthand + with a full Markdown link. + + Example: + <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) + """ + type = match.group("type") + path = match.group("path") + title = f"{titles[type]}{path}" + url = f"{urls[type]}/{path}" + if fragment := match.group("fragment"): + url += f"#{fragment}" + + return f"[{gh_icon} {title}]({url})" + + # Replace both inline and autolinks + markdown = inline_link.sub(replace_inline_link, markdown) + markdown = auto_link.sub(replace_auto_link, markdown) + + return markdown diff --git a/docs/mkdocs/javascript/edit_and_feedback.js b/docs/mkdocs/javascript/edit_and_feedback.js new file mode 100644 index 0000000000000000000000000000000000000000..0acbcfd965bd97ddff3f68ad24ca2f2a8f7f4cc1 --- /dev/null +++ b/docs/mkdocs/javascript/edit_and_feedback.js @@ -0,0 +1,47 @@ +/** + * edit_and_feedback.js + * + * Enhances MkDocs Material docs pages by: + * + * 1. Adding a "Question? Give us feedback" link + * below the "Edit" button. + * + * - The link opens a GitHub issue with a template, + * auto-filled with the current page URL and path. + * + * 2. Ensuring the edit button opens in a new tab + * with target="_blank" and rel="noopener". 
+ */ +document.addEventListener("DOMContentLoaded", function () { + const url = window.location.href; + const page = document.body.dataset.mdUrl || location.pathname; + + const feedbackLink = document.createElement("a"); + feedbackLink.href = `https://github.com/vllm-project/vllm-omni/issues/new?template=100-documentation.yml&title=${encodeURIComponent( + `[Docs] Feedback for \`${page}\`` + )}&body=${encodeURIComponent(`📄 **Reference:**\n${url}\n\n📝 **Feedback:**\n_Your response_`)}`; + feedbackLink.target = "_blank"; + feedbackLink.rel = "noopener"; + feedbackLink.title = "Provide feedback"; + feedbackLink.className = "md-content__button"; + feedbackLink.innerHTML = ` + <svg + xmlns="http://www.w3.org/2000/svg" + height="24px" + viewBox="0 -960 960 960" + width="24px" + fill="currentColor" + > + <path d="M280-280h280v-80H280v80Zm0-160h400v-80H280v80Zm0-160h400v-80H280v80Zm-80 480q-33 0-56.5-23.5T120-200v-560q0-33 23.5-56.5T200-840h560q33 0 56.5 23.5T840-760v560q0 33-23.5 56.5T760-120H200Zm0-80h560v-560H200v560Zm0-560v560-560Z"/> + </svg> + `; + + const editButton = document.querySelector('.md-content__button[href*="edit"]'); + + if (editButton && editButton.parentNode) { + editButton.insertAdjacentElement("beforebegin", feedbackLink); + + editButton.setAttribute("target", "_blank"); + editButton.setAttribute("rel", "noopener"); + } + }); diff --git a/docs/mkdocs/javascript/mathjax.js b/docs/mkdocs/javascript/mathjax.js new file mode 100644 index 0000000000000000000000000000000000000000..eb89ace0695f9d57f995fa8473de6caabfafc3a7 --- /dev/null +++ b/docs/mkdocs/javascript/mathjax.js @@ -0,0 +1,20 @@ +// Enables MathJax rendering +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } + }; + + document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() + }) diff --git a/docs/mkdocs/javascript/mermaid.js b/docs/mkdocs/javascript/mermaid.js new file mode 100644 index 0000000000000000000000000000000000000000..63676de9db57633cb5fcec3bcf312dc25d3f1f13 --- /dev/null +++ b/docs/mkdocs/javascript/mermaid.js @@ -0,0 +1,21 @@ +// Initialize Mermaid for diagram rendering +mermaid.initialize({ + startOnLoad: false, + theme: 'default', + securityLevel: 'loose', + flowchart: { + useMaxWidth: true, + htmlLabels: true + } +}); + +// Render Mermaid diagrams when page content is ready +document$.subscribe(() => { + const mermaidElements = document.querySelectorAll('.mermaid'); + if (mermaidElements.length > 0) { + mermaid.run({ + querySelector: '.mermaid', + nodes: mermaidElements + }); + } +}); diff --git a/docs/mkdocs/javascript/slack_and_forum.js b/docs/mkdocs/javascript/slack_and_forum.js new file mode 100644 index 0000000000000000000000000000000000000000..320e44910aa62aa8343efe7b4e6aaf72b62f1d27 --- /dev/null +++ b/docs/mkdocs/javascript/slack_and_forum.js @@ -0,0 +1,56 @@ +/** + * slack_and_forum.js + * + * Adds a custom Slack and Forum button to the MkDocs Material header. 
+ * + */ + +window.addEventListener('DOMContentLoaded', () => { + const headerInner = document.querySelector('.md-header__inner'); + + if (headerInner) { + const slackButton = document.createElement('button'); + slackButton.className = 'slack-button'; + slackButton.title = 'Join us on Slack'; + slackButton.style.border = 'none'; + slackButton.style.background = 'transparent'; + slackButton.style.cursor = 'pointer'; + + slackButton.innerHTML = ` + <img src="https://a.slack-edge.com/80588/marketing/img/icons/icon_slack_hash_colored.png" + style="height: 1.1rem;" + alt="Slack"> + `; + + slackButton.addEventListener('click', () => { + window.open('https://slack.vllm.ai', '_blank', 'noopener'); + }); + + const forumButton = document.createElement('button'); + forumButton.className = 'forum-button'; + forumButton.title = 'Join the Forum'; + forumButton.style.border = 'none'; + forumButton.style.background = 'transparent'; + forumButton.style.cursor = 'pointer'; + + forumButton.innerHTML = ` + <svg + xmlns="http://www.w3.org/2000/svg" + viewBox="0 -960 960 960" + fill="currentColor" + > + <path d="M817.85-198.15 698.46-317.54H320q-24.48 0-41.47-16.99T261.54-376v-11.69h424.61q25.39 0 43.47-18.08 18.07-18.08 18.07-43.46v-268.92h11.69q24.48 0 41.47 16.99 17 16.99 17 41.47v461.54ZM179.08-434.69l66.84-66.85h363.31q10.77 0 17.69-6.92 6.93-6.92 6.93-17.69v-246.77q0-10.77-6.93-17.7-6.92-6.92-17.69-6.92H203.69q-10.77 0-17.69 6.92-6.92 6.93-6.92 17.7v338.23Zm-36.93 89.46v-427.69q0-25.39 18.08-43.46 18.08-18.08 43.46-18.08h405.54q25.39 0 43.46 18.08 18.08 18.07 18.08 43.46v246.77q0 25.38-18.08 43.46-18.07 18.07-43.46 18.07H261.54L142.15-345.23Zm36.93-180.92V-797.54v271.39Z"/> + </svg> + `; + + forumButton.addEventListener('click', () => { + window.open('https://discuss.vllm.ai/', '_blank', 'noopener'); + }); + + const githubSource = document.querySelector('.md-header__source'); + if (githubSource) { + githubSource.parentNode.insertBefore(slackButton, githubSource.nextSibling); + githubSource.parentNode.insertBefore(forumButton, slackButton.nextSibling); + } + } + }); diff --git a/docs/mkdocs/overrides/main.html b/docs/mkdocs/overrides/main.html new file mode 100644 index 0000000000000000000000000000000000000000..94d9808cc760156cb7ab46e326c0267f3406125e --- /dev/null +++ b/docs/mkdocs/overrides/main.html @@ -0,0 +1 @@ +{% extends "base.html" %} diff --git a/docs/mkdocs/overrides/partials/toc-item.html b/docs/mkdocs/overrides/partials/toc-item.html new file mode 100644 index 0000000000000000000000000000000000000000..656c2158998b8613e468a3043252be2fb67b9e95 --- /dev/null +++ b/docs/mkdocs/overrides/partials/toc-item.html @@ -0,0 +1,21 @@ +<!-- Enables the use of toc_depth in document frontmatter https://github.com/squidfunk/mkdocs-material/issues/4827#issuecomment-1869812019 --> +<li class="md-nav__item"> + <a href="{{ toc_item.url }}" class="md-nav__link"> + <span class="md-ellipsis"> + {{ toc_item.title }} + </span> + </a> + + <!-- Table of contents list --> + {% if toc_item.children %} + <nav class="md-nav" aria-label="{{ toc_item.title | striptags }}"> + <ul class="md-nav__list"> + {% for toc_item in toc_item.children %} + {% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %} + {% include "partials/toc-item.html" %} + {% endif %} + {% endfor %} + </ul> + </nav> + {% endif %} + </li> diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css new file mode 100644 index 0000000000000000000000000000000000000000..a29352f5bfa346f32a08e2aa3088d2f38c26ba95 --- /dev/null 
+++ b/docs/mkdocs/stylesheets/extra.css @@ -0,0 +1,221 @@ +/* Warning for latest docs */ +.md-banner { + background-color: var(--md-warning-bg-color); + color: var(--md-warning-fg-color); +} + +/* https://christianoliff.com/blog/styling-external-links-with-an-icon-in-css/ */ +a:not(:has(svg)):not(.md-icon):not(.autorefs-external) { + align-items: center; + + &[href^="//"]::after, + &[href^="http://"]::after, + &[href^="https://"]::after { + content: ""; + width: 12px; + height: 12px; + margin-left: 4px; + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='16' height='16' stroke='gray' viewBox='0 0 16 16'%3E%3Cpath fill-rule='evenodd' d='M8.636 3.5a.5.5 0 0 0-.5-.5H1.5A1.5 1.5 0 0 0 0 4.5v10A1.5 1.5 0 0 0 1.5 16h10a1.5 1.5 0 0 0 1.5-1.5V7.864a.5.5 0 0 0-1 0V14.5a.5.5 0 0 1-.5.5h-10a.5.5 0 0 1-.5-.5v-10a.5.5 0 0 1 .5-.5h6.636a.5.5 0 0 0 .5-.5z'/%3E%3Cpath fill-rule='evenodd' d='M16 .5a.5.5 0 0 0-.5-.5h-5a.5.5 0 0 0 0 1h3.793L6.146 9.146a.5.5 0 1 0 .708.708L15 1.707V5.5a.5.5 0 0 0 1 0v-5z'/%3E%3C/svg%3E"); + background-position: center; + background-repeat: no-repeat; + background-size: contain; + display: inline-block; + } +} + +a[href*="localhost"]::after, +a[href*="127.0.0.1"]::after, + +/* Hide external link icons for all links */ +a[href^="//"]::after, +a[href^="http://"]::after, +a[href^="https://"]::after { + display: none !important; +} + +/* Light mode: darker section titles */ +body[data-md-color-scheme="default"] .md-nav__item--section > label.md-nav__link .md-ellipsis { + color: rgba(0, 0, 0, 0.7) !important; + font-weight: 700; +} + +/* Dark mode: lighter gray section titles */ +body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link .md-ellipsis { + color: rgba(255, 255, 255, 0.75) !important; + font-weight: 700; +} + +/* Custom admonitions */ +:root { + --md-admonition-icon--announcement: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16"><path d="M3.25 9a.75.75 0 0 1 .75.75c0 2.142.456 3.828.733 4.653a.122.122 0 0 0 .05.064.212.212 0 0 0 .117.033h1.31c.085 0 .18-.042.258-.152a.45.45 0 0 0 .075-.366A16.743 16.743 0 0 1 6 9.75a.75.75 0 0 1 1.5 0c0 1.588.25 2.926.494 3.85.293 1.113-.504 2.4-1.783 2.4H4.9c-.686 0-1.35-.41-1.589-1.12A16.4 16.4 0 0 1 2.5 9.75.75.75 0 0 1 3.25 9Z"></path><path d="M0 6a4 4 0 0 1 4-4h2.75a.75.75 0 0 1 .75.75v6.5a.75.75 0 0 1-.75.75H4a4 4 0 0 1-4-4Zm4-2.5a2.5 2.5 0 1 0 0 5h2v-5Z"></path><path d="M15.59.082A.75.75 0 0 1 16 .75v10.5a.75.75 0 0 1-1.189.608l-.002-.001h.001l-.014-.01a5.775 5.775 0 0 0-.422-.25 10.63 10.63 0 0 0-1.469-.64C11.576 10.484 9.536 10 6.75 10a.75.75 0 0 1 0-1.5c2.964 0 5.174.516 6.658 1.043.423.151.787.302 1.092.443V2.014c-.305.14-.669.292-1.092.443C11.924 2.984 9.713 3.5 6.75 3.5a.75.75 0 0 1 0-1.5c2.786 0 4.826-.484 6.155-.957.665-.236 1.154-.47 1.47-.64.144-.077.284-.161.421-.25l.014-.01a.75.75 0 0 1 .78-.061Z"></path></svg>'); + --md-admonition-icon--important: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16"><path d="M4.47.22A.749.749 0 0 1 5 0h6c.199 0 .389.079.53.22l4.25 4.25c.141.14.22.331.22.53v6a.749.749 0 0 1-.22.53l-4.25 4.25A.749.749 0 0 1 11 16H5a.749.749 0 0 1-.53-.22L.22 11.53A.749.749 0 0 1 0 11V5c0-.199.079-.389.22-.53Zm.84 1.28L1.5 5.31v5.38l3.81 3.81h5.38l3.81-3.81V5.31L10.69 1.5ZM8 4a.75.75 0 0 1 .75.75v3.5a.75.75 0 0 1-1.5 0v-3.5A.75.75 0 0 1 8 4Zm0 8a1 1 0 1 1 0-2 1 1 0 0 1 0 2Z"></path></svg>'); + 
--md-admonition-icon--code: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="m11.28 3.22 4.25 4.25a.75.75 0 0 1 0 1.06l-4.25 4.25a.749.749 0 0 1-1.275-.326.75.75 0 0 1 .215-.734L13.94 8l-3.72-3.72a.749.749 0 0 1 .326-1.275.75.75 0 0 1 .734.215m-6.56 0a.75.75 0 0 1 1.042.018.75.75 0 0 1 .018 1.042L2.06 8l3.72 3.72a.749.749 0 0 1-.326 1.275.75.75 0 0 1-.734-.215L.47 8.53a.75.75 0 0 1 0-1.06Z"/></svg>'); + --md-admonition-icon--console: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M0 2.75C0 1.784.784 1 1.75 1h12.5c.966 0 1.75.784 1.75 1.75v10.5A1.75 1.75 0 0 1 14.25 15H1.75A1.75 1.75 0 0 1 0 13.25Zm1.75-.25a.25.25 0 0 0-.25.25v10.5c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25V2.75a.25.25 0 0 0-.25-.25ZM7.25 8a.75.75 0 0 1-.22.53l-2.25 2.25a.749.749 0 0 1-1.275-.326.75.75 0 0 1 .215-.734L5.44 8 3.72 6.28a.749.749 0 0 1 .326-1.275.75.75 0 0 1 .734.215l2.25 2.25c.141.14.22.331.22.53m1.5 1.5h3a.75.75 0 0 1 0 1.5h-3a.75.75 0 0 1 0-1.5"/></svg>'); +} + +.md-typeset .admonition.announcement, +.md-typeset details.announcement { + border-color: rgb(255, 110, 66); +} +.md-typeset .admonition.important, +.md-typeset details.important { + border-color: rgb(239, 85, 82); +} +.md-typeset .admonition.code, +.md-typeset details.code { + border-color: #64dd17 +} +.md-typeset .admonition.console, +.md-typeset details.console { + border-color: #64dd17 +} + +.md-typeset .announcement > .admonition-title, +.md-typeset .announcement > summary { + background-color: rgb(255, 110, 66, 0.1); +} +.md-typeset .important > .admonition-title, +.md-typeset .important > summary { + background-color: rgb(239, 85, 82, 0.1); +} +.md-typeset .code > .admonition-title, +.md-typeset .code > summary { + background-color: #64dd171a; +} +.md-typeset .console > .admonition-title, +.md-typeset .console > summary { + background-color: #64dd171a; +} + +.md-typeset .announcement > .admonition-title::before, +.md-typeset .announcement > summary::before { + background-color: rgb(239, 85, 82); + -webkit-mask-image: var(--md-admonition-icon--announcement); + mask-image: var(--md-admonition-icon--announcement); +} +.md-typeset .important > .admonition-title::before, +.md-typeset .important > summary::before { + background-color: rgb(239, 85, 82); + -webkit-mask-image: var(--md-admonition-icon--important); + mask-image: var(--md-admonition-icon--important); +} +.md-typeset .code > .admonition-title::before, +.md-typeset .code > summary::before { + background-color: #64dd17; + -webkit-mask-image: var(--md-admonition-icon--code); + mask-image: var(--md-admonition-icon--code); +} +.md-typeset .console > .admonition-title::before, +.md-typeset .console > summary::before { + background-color: #64dd17; + -webkit-mask-image: var(--md-admonition-icon--console); + mask-image: var(--md-admonition-icon--console); +} + +/* Make label fully visible on hover */ +.md-content__button[href*="edit"]:hover::after { + opacity: 1; +} + +/* Hide edit button on generated docs/examples pages */ +@media (min-width: 960px) { + .md-content__button[href*="docs/examples/"] { + display: none !important; + } +} + +.md-content__button-wrapper { + position: absolute; + top: 0.6rem; + right: 0.8rem; + display: flex; + flex-direction: row; + align-items: center; + gap: 0.4rem; + z-index: 1; +} + +.md-content__button-wrapper a { + display: inline-flex; + align-items: center; + justify-content: center; + height: 24px; + width: 24px; + color: 
var(--md-default-fg-color); + text-decoration: none; +} + +.md-content__button-wrapper a:hover { + color: var(--md-accent-fg-color); +} + +/* Slack and Forum css */ +.slack-button, +.forum-button { + display: inline-flex; + align-items: center; + justify-content: center; + margin-left: 0.4rem; + height: 24px; +} + +.slack-button img { + height: 18px; + filter: none !important; +} + +.slack-button:hover, +.forum-button:hover { + opacity: 0.7; +} + +.forum-button svg { + height: 28px; + opacity: 0.9; + transform: translateY(2px); +} + +/* For logo css */ +[data-md-color-scheme="default"] .logo-dark { + display: none; +} + +[data-md-color-scheme="slate"] .logo-light { + display: none; +} + +/* Outline for content tabs */ +.md-typeset .tabbed-set { + border: 0.075rem solid var(--md-default-fg-color); + border-radius: 0.2rem; +} + +.md-typeset .tabbed-content { + padding: 0 0.6em; +} + +/* Hide link icon in header logo */ +.md-header__button.md-logo :is(img, svg) { + pointer-events: none; +} + +.md-header__button.md-logo::after { + display: none !important; +} + +/* Hide link icons in content tabs (tabbed content) */ +.md-typeset .tabbed-set > label::after, +.md-typeset .tabbed-set > input:checked + label::after { + display: none !important; +} + +/* Hide link icons in navigation tabs */ +.md-nav__link[href]::after, +.md-nav__item--nested > .md-nav__link::after { + display: none !important; +} + +/* Hide link icons in top navigation tabs */ +.md-tabs__link::after { + display: none !important; +} diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md new file mode 100644 index 0000000000000000000000000000000000000000..269c9261ea8514d4adc08cb437c59cc6117a84ca --- /dev/null +++ b/docs/models/supported_models.md @@ -0,0 +1,61 @@ +# Supported Models + +vLLM-Omni supports unified multimodal comprehension and generation models across various tasks. + +## Model Implementation + +If vLLM-Omni natively supports a model, its implementation can be found in <gh-file:vllm_omni/model_executor/models> and <gh-file:vllm_omni/diffusion/models>. 
+ +## List of Supported Models for Nvidia GPU / AMD GPU + +<style> +th { + white-space: nowrap; + min-width: 0 !important; +} +</style> + +| Architecture | Models | Example HF Models | +|--------------|--------|-------------------| +| `Qwen3OmniMoeForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | +| `Qwen2_5OmniForConditionalGeneration` | Qwen2.5-Omni | `Qwen/Qwen2.5-Omni-7B`, `Qwen/Qwen2.5-Omni-3B` | +| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | +| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | +| `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | +| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | +| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` | +| `QwenImageLayeredPipeline` | Qwen-Image-Layered | `Qwen/Qwen-Image-Layered` | +|`ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | +| `WanPipeline` | Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | +| `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | +| `OvisImagePipeline` | Ovis-Image | `OvisAI/Ovis-Image` | +|`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | +|`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | +|`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` | +|`Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | +|`FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | +|`StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | +|`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | +|`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-VoiceDesign | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | +|`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-Base | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` | + + +## List of Supported Models for NPU + +<style> +th { + white-space: nowrap; + min-width: 0 !important; +} +</style> + +| Architecture | Models | Example HF Models | +|--------------|--------|-------------------| +| `Qwen2_5OmniForConditionalGeneration` | Qwen2.5-Omni | `Qwen/Qwen2.5-Omni-7B`, `Qwen/Qwen2.5-Omni-3B`| +| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | +| `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | +| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | +| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` | +| `QwenImageLayeredPipeline` | Qwen-Image-Layered | `Qwen/Qwen-Image-Layered` | +| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2511 | `Qwen/Qwen-Image-Edit-2511` | +|`ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | diff --git a/docs/serving/image_edit_api.md b/docs/serving/image_edit_api.md new file mode 100644 index 0000000000000000000000000000000000000000..d254ac06ad7d7be714a16cd0ebd925f16ccc7530 --- /dev/null +++ b/docs/serving/image_edit_api.md @@ -0,0 +1,205 @@ +# Image Edit API + +vLLM-Omni provides an OpenAI DALL-E compatible API for image edit using diffusion models. + +Each server instance runs a single model (specified at startup via `vllm serve <model> --omni`). + +## Quick Start + +### Start the Server + +For example... 
+
+```bash
+# Qwen-Image-Edit
+vllm serve Qwen/Qwen-Image-Edit-2511 --omni --port 8000
+```
+
+### Generate Images
+
+**Using curl:**
+
+```bash
+curl -s -D >(grep -i x-request-id >&2) \
+  -o >(jq -r '.data[0].b64_json' | base64 --decode > gift-basket.png) \
+  -X POST "http://localhost:8000/v1/images/edits" \
+  -F "model=xxx" \
+  -F "image=@./xx.png" \
+  -F "prompt='this bear is wearing sportswear, holding a basketball, and bending one leg.'" \
+  -F "size=1024x1024" \
+  -F "output_format=png"
+```
+
+**Using OpenAI SDK:**
+
+```python
+import base64
+
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="None",
+    base_url="http://localhost:8000/v1"
+)
+
+input_image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/qwen-bear.png"
+
+result = client.images.edit(
+    image=[],
+    model="Qwen/Qwen-Image-Edit-2511",
+    prompt="Change the bears in the two input images into walking together.",
+    size='512x512',
+    stream=False,
+    output_format='jpeg',
+    extra_body={
+        # Provide the input image(s) by URL; the same sample image is reused
+        # for both inputs here purely for illustration.
+        "url": [input_image_url, input_image_url],
+        "num_inference_steps": 50,
+        "guidance_scale": 1,
+        "seed": 777,
+    }
+)
+
+image_base64 = result.data[0].b64_json
+image_bytes = base64.b64decode(image_base64)
+
+# Save the image to a file
+with open("edit_out_http.jpeg", "wb") as f:
+    f.write(image_bytes)
+```
+
+## API Reference
+
+### Endpoint
+
+```
+POST /v1/images/edits
+Content-Type: multipart/form-data
+```
+
+### Request Parameters
+
+#### OpenAI Standard Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `prompt` | string | **required** | A text description of the desired image |
+| `model` | string | server's model | Model to use (optional, should match server if specified) |
+| `image` | string or array | **required** | The image(s) to edit, uploaded as multipart form data |
+| `n` | integer | 1 | Number of images to generate (1-10) |
+| `size` | string | "auto" | Image dimensions in WxH format (e.g., "1024x1024", "512x512"); when set to "auto", the size is taken from the first input image |
+| `response_format` | string | "b64_json" | Response format (only "b64_json" supported) |
+| `user` | string | null | User identifier for tracking |
+| `output_format` | string | "png" | The format in which the generated images are returned. Must be one of "png", "jpg", "jpeg", "webp". |
+| `output_compression` | integer | 100 | The compression level (0-100%) for the generated images. |
+| `background` | string or null | "auto" | Allows setting transparency for the background of the generated image(s). |
+
+#### vllm-omni Extension Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `url` | string or array | None | URL(s) of the image(s) to edit, as an alternative to uploading files via `image` |
+| `negative_prompt` | string | null | Text describing what to avoid in the image |
+| `num_inference_steps` | integer | model defaults | Number of diffusion steps |
+| `guidance_scale` | float | model defaults | Classifier-free guidance scale (typically 0.0-20.0) |
+| `true_cfg_scale` | float | model defaults | True CFG scale (model-specific parameter, may be ignored if not supported) |
+| `seed` | integer | null | Random seed for reproducibility |
+
+### Response Format
+
+```json
+{
+  "created": 1701234567,
+  "data": [
+    {
+      "b64_json": "<base64-encoded PNG>",
+      "url": null,
+      "revised_prompt": null
+    }
+  ],
+  "output_format": null,
+  "size": null
+}
+```
+
+## Examples
+
+### Multiple Image Inputs
+
+```bash
+curl -s -D >(grep -i x-request-id >&2) \
+  -o >(jq -r '.data[0].b64_json' | base64 --decode > gift-basket.png) \
+  -X POST "http://localhost:8000/v1/images/edits" \
+  -F "model=xxx" \
+  -F "image=@xx.png" \
+  -F "image=@xx.png" \
+  -F "prompt='this bear is wearing sportswear, holding a basketball, and bending one leg.'" \
+  -F "size=1024x1024" \
+  -F "output_format=png"
+```
+
+## Parameter Handling
+
+The API passes parameters directly to the diffusion pipeline without model-specific transformation:
+
+- **Default values**: When parameters are not specified, the underlying model uses its own defaults
+- **Pass-through design**: User-provided values are forwarded directly to the diffusion engine
+- **Minimal validation**: Only basic type checking and range validation at the API level
+
+### Parameter Compatibility
+
+The API passes parameters directly to the diffusion pipeline without model-specific validation.
+
+- Unsupported parameters may be silently ignored by the model
+- Incompatible values will result in errors from the underlying pipeline
+- Recommended values vary by model - consult model documentation
+
+**Best Practice:** Start with the model's recommended parameters, then adjust based on your needs.
+
+## Error Responses
+
+### 400 Bad Request
+
+Invalid parameters (e.g., model mismatch):
+
+```json
+{
+  "detail": "Invalid size format: '1024x'. Expected format: 'WIDTHxHEIGHT' (e.g., '1024x1024')."
+}
+```
+
+### 422 Unprocessable Entity
+
+Validation errors (missing required fields):
+
+```json
+{
+  "detail": "Field 'image' or 'url' is required"
+}
+```
+
+## Troubleshooting
+
+### Server Not Running
+
+```bash
+# Check if server is responding
+curl -X POST http://localhost:8000/v1/images/edits \
+  -F "prompt='test'"
+```
+
+### Out of Memory
+
+If you encounter OOM errors:
+1. Reduce image size: `"size": "512x512"`
+2. Reduce inference steps: `"num_inference_steps": 25`
+
+## Development
+
+Enable debug logging to see prompts and generation details:
+
+```bash
+vllm serve Qwen/Qwen-Image-Edit-2511 --omni \
+    --uvicorn-log-level debug
+```
diff --git a/docs/serving/image_generation_api.md b/docs/serving/image_generation_api.md
new file mode 100644
index 0000000000000000000000000000000000000000..747cef99567bc92e4f99983823a48f176ffe1992
--- /dev/null
+++ b/docs/serving/image_generation_api.md
@@ -0,0 +1,249 @@
+# Image Generation API
+
+vLLM-Omni provides an OpenAI DALL-E compatible API for text-to-image generation using diffusion models.
+
+Each server instance runs a single model (specified at startup via `vllm serve <model> --omni`).
+
+## Quick Start
+
+### Start the Server
+
+For example...
+ +```bash +# Qwen-Image +vllm serve Qwen/Qwen-Image --omni --port 8000 + +# Z-Image Turbo +vllm serve Tongyi-MAI/Z-Image-Turbo --omni --port 8000 +``` + +### Generate Images + +**Using curl:** + +```bash +curl -X POST http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "a dragon laying over the spine of the Green Mountains of Vermont", + "size": "1024x1024", + "seed": 42 + }' | jq -r '.data[0].b64_json' | base64 -d > dragon.png +``` + +**Using Python:** + +```python +import requests +import base64 +from PIL import Image +import io + +response = requests.post( + "http://localhost:8000/v1/images/generations", + json={ + "prompt": "a black and white cat wearing a princess tiara", + "size": "1024x1024", + "num_inference_steps": 50, + "seed": 42, + } +) + +# Decode and save +img_data = response.json()["data"][0]["b64_json"] +img_bytes = base64.b64decode(img_data) +img = Image.open(io.BytesIO(img_bytes)) +img.save("cat.png") +``` + +**Using OpenAI SDK:** + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="none") + +response = client.images.generate( + model="Qwen/Qwen-Image", + prompt="a horse jumping over a fence nearby a babbling brook", + n=1, + size="1024x1024", + response_format="b64_json" +) + +# Note: Extension parameters (seed, steps, cfg) require direct HTTP requests +``` + +## API Reference + +### Endpoint + +``` +POST /v1/images/generations +Content-Type: application/json +``` + +### Request Parameters + +#### OpenAI Standard Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prompt` | string | **required** | Text description of the desired image | +| `model` | string | server's model | Model to use (optional, should match server if specified) | +| `n` | integer | 1 | Number of images to generate (1-10) | +| `size` | string | model defaults | Image dimensions in WxH format (e.g., "1024x1024", "512x512") | +| `response_format` | string | "b64_json" | Response format (only "b64_json" supported) | +| `user` | string | null | User identifier for tracking | + +#### vllm-omni Extension Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `negative_prompt` | string | null | Text describing what to avoid in the image | +| `num_inference_steps` | integer | model defaults | Number of diffusion steps | +| `guidance_scale` | float | model defaults | Classifier-free guidance scale (typically 0.0-20.0) | +| `true_cfg_scale` | float | model defaults | True CFG scale (model-specific parameter, may be ignored if not supported) | +| `seed` | integer | null | Random seed for reproducibility | + +### Response Format + +```json +{ + "created": 1701234567, + "data": [ + { + "b64_json": "<base64-encoded PNG>", + "url": null, + "revised_prompt": null + } + ] +} +``` + +## Examples + +### Multiple Images + +```bash +curl -X POST http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "a steampunk city set in a valley of the Adirondack mountains", + "n": 4, + "size": "1024x1024", + "seed": 123 + }' +``` + +This generates 4 images in a single request. 
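+
+Each of the `n` images comes back as a separate entry in `data`. As a rough sketch (assuming the server above is running locally), you can decode and save all of them like this:
+
+```python
+import base64
+import requests
+
+response = requests.post(
+    "http://localhost:8000/v1/images/generations",
+    json={
+        "prompt": "a steampunk city set in a valley of the Adirondack mountains",
+        "n": 4,
+        "size": "1024x1024",
+        "seed": 123,
+    },
+)
+
+# Each generated image is a separate entry in `data`
+for i, item in enumerate(response.json()["data"]):
+    with open(f"steampunk_{i}.png", "wb") as f:
+        f.write(base64.b64decode(item["b64_json"]))
+```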
+ +### With Negative Prompt + +```python +response = requests.post( + "http://localhost:8000/v1/images/generations", + json={ + "prompt": "a portrait of a skier in deep powder snow", + "negative_prompt": "blurry, low quality, distorted, ugly", + "num_inference_steps": 100, + "size": "1024x1024", + } +) +``` + +## Parameter Handling + +The API passes parameters directly to the diffusion pipeline without model-specific transformation: + +- **Default values**: When parameters are not specified, the underlying model uses its own defaults +- **Pass-through design**: User-provided values are forwarded directly to the diffusion engine +- **Minimal validation**: Only basic type checking and range validation at the API level + +### Parameter Compatibility + +The API passes parameters directly to the diffusion pipeline without model-specific validation. + +- Unsupported parameters may be silently ignored by the model +- Incompatible values will result in errors from the underlying pipeline +- Recommended values vary by model - consult model documentation + +**Best Practice:** Start with the model's recommended parameters, then adjust based on your needs. + +## Error Responses + +### 400 Bad Request + +Invalid parameters (e.g., model mismatch): + +```json +{ + "detail": "Invalid size format: '1024x'. Expected format: 'WIDTHxHEIGHT' (e.g., '1024x1024')." +} +``` + +### 422 Unprocessable Entity + +Validation errors (missing required fields): + +```json +{ + "detail": [ + { + "loc": ["body", "prompt"], + "msg": "field required", + "type": "value_error.missing" + } + ] +} +``` + +### 503 Service Unavailable + +Diffusion engine not initialized: + +```json +{ + "detail": "Diffusion engine not initialized. Start server with a diffusion model." +} +``` + +## Troubleshooting + +### Server Not Running + +```bash +# Check if server is responding +curl http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{"prompt": "test"}' +``` + +### Out of Memory + +If you encounter OOM errors: +1. Reduce image size: `"size": "512x512"` +2. Reduce inference steps: `"num_inference_steps": 25` +3. 
Generate fewer images: `"n": 1` + +## Testing + +Run the test suite to verify functionality: + +```bash +# All image generation tests +pytest tests/entrypoints/openai_api/test_image_server.py -v + +# Specific test +pytest tests/entrypoints/openai_api/test_image_server.py::test_generate_single_image -v +``` + +## Development + +Enable debug logging to see prompts and generation details: + +```bash +vllm serve Qwen/Qwen-Image --omni \ + --uvicorn-log-level debug +``` diff --git a/docs/source/architecture/ar-dit-main-architecture.png b/docs/source/architecture/ar-dit-main-architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..200fddf2b7fbab7d64c156867931ae753c2d974f Binary files /dev/null and b/docs/source/architecture/ar-dit-main-architecture.png differ diff --git a/docs/source/architecture/ar-main-architecture.png b/docs/source/architecture/ar-main-architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..642c3c008688a1dd16181ffb8426732b9cd7278d Binary files /dev/null and b/docs/source/architecture/ar-main-architecture.png differ diff --git a/docs/source/architecture/dit-main-architecture.png b/docs/source/architecture/dit-main-architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..97672e145ef8ca808cd9dcab565bf211607d356c Binary files /dev/null and b/docs/source/architecture/dit-main-architecture.png differ diff --git a/docs/source/architecture/omni-modality-model-architecture.png b/docs/source/architecture/omni-modality-model-architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..53978b32366ac9c701053c7d41b5192c600b7c31 Binary files /dev/null and b/docs/source/architecture/omni-modality-model-architecture.png differ diff --git a/docs/source/architecture/vllm-omni-dataflow-between-stages.png b/docs/source/architecture/vllm-omni-dataflow-between-stages.png new file mode 100644 index 0000000000000000000000000000000000000000..cdbc9a8b7b3766aa1902cc26085c7c3ba04f4047 Binary files /dev/null and b/docs/source/architecture/vllm-omni-dataflow-between-stages.png differ diff --git a/docs/source/architecture/vllm-omni-diffusion-flow.png b/docs/source/architecture/vllm-omni-diffusion-flow.png new file mode 100644 index 0000000000000000000000000000000000000000..92a4cfe649a51a1d9a0b52975830f739b2cf8733 Binary files /dev/null and b/docs/source/architecture/vllm-omni-diffusion-flow.png differ diff --git a/docs/source/architecture/vllm-omni-main-architecture.png b/docs/source/architecture/vllm-omni-main-architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..98b7a979242b95e9a964a3150248cdb90543235a Binary files /dev/null and b/docs/source/architecture/vllm-omni-main-architecture.png differ diff --git a/docs/source/architecture/vllm-omni-user-interface.png b/docs/source/architecture/vllm-omni-user-interface.png new file mode 100644 index 0000000000000000000000000000000000000000..867b3ae2fca46044ceee1a0be74a362cb7fe65ce Binary files /dev/null and b/docs/source/architecture/vllm-omni-user-interface.png differ diff --git a/docs/source/logos/vllm-logo-only-light.ico b/docs/source/logos/vllm-logo-only-light.ico new file mode 100644 index 0000000000000000000000000000000000000000..27528ceebfff401d0516b73099381c7425aaff3a Binary files /dev/null and b/docs/source/logos/vllm-logo-only-light.ico differ diff --git a/docs/source/logos/vllm-omni-logo.png b/docs/source/logos/vllm-omni-logo.png new file mode 100644 index 
0000000000000000000000000000000000000000..c054ff46fb55692ccdcfa67ad9f5661ab2849ea0 Binary files /dev/null and b/docs/source/logos/vllm-omni-logo.png differ
diff --git a/docs/usage/faq.md b/docs/usage/faq.md
new file mode 100644
index 0000000000000000000000000000000000000000..e109439e849445b02dc99c9afcc5fa39af60cade
--- /dev/null
+++ b/docs/usage/faq.md
@@ -0,0 +1,29 @@
+# Frequently Asked Questions
+
+> Q: How many chips do I need to run inference for a model in vLLM-Omni?
+
+A: vLLM-Omni natively supports disaggregated deployment of the different stages within a model. The current restriction is that one chip can host only one AutoRegressive model stage, because of vLLM's unified KV cache management. Stages of other types can coexist on the same chip. This restriction will be lifted in a later version.
+
+> Q: When trying to run the examples, I encounter an error about the librosa or soundfile backend. How do I solve it?
+
+A: If you encounter an error about the librosa backend, install ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
+> Q: I see GPU OOM or "free memory is less than desired GPU memory utilization" errors. How can I fix it?
+
+A: Refer to [GPU memory calculation and configuration](../configuration/gpu_memory_utilization.md) for guidance on tuning `gpu_memory_utilization` and related settings.
+
+> Q: I ran into a bug or CI problem that is urgent. How can I get it resolved?
+
+A: First, check the existing [issues](https://github.com/vllm-project/vllm-omni/issues) for possible solutions. If none of them addresses your problem and it is urgent, please reach out to these [volunteers](https://docs.vllm.ai/projects/vllm-omni/en/latest/community/volunteers/) for help.
+
+> Q: Does vLLM-Omni support AWQ or any other quantization?
+
+A: vLLM-Omni partitions a model into several stages. AR stages reuse the main LLMEngine logic from vLLM, so quantization methods currently supported in vLLM should also work for those stages, although systematic verification is still ongoing. Quantization for the DiffusionEngine is a work in progress. Please stay tuned, and contributions are welcome!
+
+> Q: Does vLLM-Omni support multimodal streaming input and output?
+
+A: Not yet. It is already on the [Roadmap](https://github.com/vllm-project/vllm-omni/issues/165). Please stay tuned!
diff --git a/docs/user_guide/diffusion/cache_dit_acceleration.md b/docs/user_guide/diffusion/cache_dit_acceleration.md
new file mode 100644
index 0000000000000000000000000000000000000000..fd1bd522a20225dfb87bf26c7f4037a274cfe8dc
--- /dev/null
+++ b/docs/user_guide/diffusion/cache_dit_acceleration.md
@@ -0,0 +1,228 @@
+# Cache-DiT Acceleration Guide
+
+This guide explains how to use cache-dit acceleration in vLLM-Omni to speed up diffusion model inference.
+
+## Overview
+
+Cache-dit is a library that accelerates diffusion transformer models through intelligent caching mechanisms. It supports multiple acceleration techniques that can be combined for optimal performance:
+
+- **DBCache**: Dual Block Cache for reducing redundant computations
+- **TaylorSeer**: Taylor expansion-based forecasting for faster inference
+- **SCM**: Step Computation Masking for selective step computation
+
+## Quick Start
+
+### Basic Usage
+
+Enable cache-dit acceleration by simply setting `cache_backend="cache_dit"`.
Cache-dit will use its recommended default parameters: + +```python +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# Simplest way: just enable cache-dit with default parameters +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="cache_dit", +) + +images = omni.generate( + "a beautiful landscape", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +**Default Parameters**: When `cache_config` is not provided, cache-dit uses optimized default values. See the [Configuration Reference](#configuration-reference) section for a complete list of all parameters and their default values. + +### Custom Configuration + +To customize cache-dit settings, provide a `cache_config` dictionary, for example: + +```python +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="cache_dit", + cache_config={ + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.12, + }, +) +``` + +## Online Serving (OpenAI-Compatible) + +Enable Cache-DiT for online serving by passing `--cache-backend cache_dit` when starting the server: + +```bash +# Use Cache-DiT default (recommended) parameters +vllm serve Qwen/Qwen-Image --omni --port 8091 --cache-backend cache_dit +``` + +To customize Cache-DiT settings for online serving, pass a JSON string via `--cache-config`: + +```bash +vllm serve Qwen/Qwen-Image --omni --port 8091 \ + --cache-backend cache_dit \ + --cache-config '{"Fn_compute_blocks": 1, "Bn_compute_blocks": 0, "max_warmup_steps": 4, "residual_diff_threshold": 0.12}' +``` + +## Acceleration Methods + +For comprehensive illustration, please view cache-dit [User_Guide](https://cache-dit.readthedocs.io/en/latest/user_guide/OVERVIEWS/) + +### 1. DBCache (Dual Block Cache) + +DBCache intelligently caches intermediate transformer block outputs when the residual differences between consecutive steps are small, reducing redundant computations without sacrificing quality. + +**Key Parameters**: + +- `Fn_compute_blocks` (int, default: 1): Number of **first n** transformer blocks used to compute stable feature differences. Higher values provide more accurate caching decisions but increase computation. +- `Bn_compute_blocks` (int, default: 0): Number of **last n** transformer blocks used for additional fusion. These blocks act as an auto-scaler for approximate hidden states. +- `max_warmup_steps` (int, default: 4): Number of initial steps where caching is disabled to ensure the model learns sufficient features before caching begins. Optimized for few-step distilled models. +- `residual_diff_threshold` (float, default: 0.24): Threshold for residual difference. Higher values lead to faster performance but may reduce precision. Default uses a relatively higher threshold for more aggressive caching. +- `max_cached_steps` (int, default: -1): Maximum number of cached steps. Set to -1 for unlimited caching. +- `max_continuous_cached_steps` (int, default: 3): Maximum number of consecutive cached steps. Limits consecutive caching to prevent precision degradation. 
+ +**Example Configuration**: + +```python +cache_config={ + "Fn_compute_blocks": 8, # Use first 8 blocks for difference computation + "Bn_compute_blocks": 0, # No additional fusion blocks + "max_warmup_steps": 8, # Cache after 8 warmup steps + "residual_diff_threshold": 0.12, # Higher threshold for faster inference + "max_cached_steps": -1, # No limit on cached steps +} +``` + +**Performance Tips**: + +- Default `Fn_compute_blocks=1` works well for most cases. Increase to 8-12 for larger models or when more accuracy is needed +- Increase `residual_diff_threshold` (e.g., 0.12-0.15) for faster inference with slight quality trade-off, or decrease from default 0.24 for higher quality +- Default `max_warmup_steps=4` is optimized for few-step models. Increase to 6-8 for more steps if needed + +### 2. TaylorSeer + +TaylorSeer uses Taylor expansion to forecast future hidden states, allowing the model to skip some computation steps while maintaining quality. + +**Key Parameters**: + +- `enable_taylorseer` (bool, default: False): Enable TaylorSeer acceleration +- `taylorseer_order` (int, default: 1): Order of Taylor expansion. Higher orders provide better accuracy but require more computation. + +**Example Configuration**: + +```python +cache_config={ + "enable_taylorseer": True, + "taylorseer_order": 1, # First-order Taylor expansion +} +``` + +**Performance Tips**: + +- Use `taylorseer_order=1` for most cases (good balance of speed and quality) +- Combine with DBCache for maximum acceleration +- Higher orders (2-3) may improve quality but reduce speed gains + +### 3. SCM (Step Computation Masking) + +SCM allows you to specify which steps must be computed and which can use cached results, similar to LeMiCa/EasyCache style acceleration. + +**Key Parameters**: + +- `scm_steps_mask_policy` (str | None, default: None): Predefined mask policy. Options: + - `None`: SCM disabled (default) + - `"slow"`: More compute steps, higher quality (18 compute steps out of 28) + - `"medium"`: Balanced (15 compute steps out of 28) + - `"fast"`: More cache steps, faster inference (11 compute steps out of 28) + - `"ultra"`: Maximum speed (8 compute steps out of 28) +- `scm_steps_policy` (str, default: "dynamic"): Policy for cached steps: + - `"dynamic"`: Use dynamic cache for masked steps (recommended) + - `"static"`: Use static cache for masked steps + +**Example Configuration**: + +```python +cache_config={ + "scm_steps_mask_policy": "medium", # Balanced speed/quality + "scm_steps_policy": "dynamic", # Use dynamic cache +} +``` + +**Performance Tips**: + +- SCM is disabled by default (`scm_steps_mask_policy=None`). 
Enable it by setting a policy value if you need additional acceleration +- Start with `"medium"` policy and adjust based on quality requirements +- Use `"fast"` or `"ultra"` for maximum speed when quality can be slightly compromised +- `"dynamic"` policy generally provides better quality than `"static"` +- SCM mask is automatically regenerated when `num_inference_steps` changes during inference + +## Configuration Reference + +### DiffusionCacheConfig Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `Fn_compute_blocks` | int | 1 | First n blocks for difference computation (optimized for single-transformer models) | +| `Bn_compute_blocks` | int | 0 | Last n blocks for fusion | +| `max_warmup_steps` | int | 4 | Steps before caching starts (optimized for few-step distilled models) | +| `max_cached_steps` | int | -1 | Max cached steps (-1 = unlimited) | +| `max_continuous_cached_steps` | int | 3 | Max consecutive cached steps (prevents precision degradation) | +| `residual_diff_threshold` | float | 0.24 | Residual difference threshold (higher for more aggressive caching) | +| `num_inference_steps` | int \| None | None | Initial inference steps for SCM mask generation (optional, auto-refreshed during inference) | +| `enable_taylorseer` | bool | False | Enable TaylorSeer acceleration (not suitable for few-step distilled models) | +| `taylorseer_order` | int | 1 | Taylor expansion order | +| `scm_steps_mask_policy` | str \| None | None | SCM mask policy (None, "slow", "medium", "fast", "ultra") | +| `scm_steps_policy` | str | "dynamic" | SCM computation policy ("dynamic" or "static") | + +## Example: Accelerate Text-to-Image Generation with CacheDiT + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example with cache-dit acceleration. + +```bash +# Enable cache-dit with hybrid acceleration +cd examples/offline_inference/text_to_image +python text_to_image.py \ + --model Qwen/Qwen-Image \ + --prompt "a cup of coffee on the table" \ + --cache_backend cache_dit \ + --num_inference_steps 50 +``` + + +The script uses cache-dit acceleration with a hybrid configuration combining DBCache, SCM, and TaylorSeer: + +```python +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="cache_dit", + cache_config={ + # Scheme: Hybrid DBCache + SCM + TaylorSeer + # DBCache + "Fn_compute_blocks": 8, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.12, + # TaylorSeer + "enable_taylorseer": True, + "taylorseer_order": 1, + # SCM + "scm_steps_mask_policy": "fast", # Set to None to disable SCM + "scm_steps_policy": "dynamic", + }, +) +``` + +You can customize the configuration by modifying the `cache_config` dictionary to use only specific methods (e.g., DBCache only, DBCache + SCM, etc.) based on your quality and speed requirements. + +To test another model, you can modify `--model` with the target model identifier like `Tongyi-MAI/Z-Image-Turbo` and update `cache_config` according the model architecture (e.g., number of transformer blocks). 
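+
+For example, a DBCache-only variant simply leaves TaylorSeer and SCM at their disabled defaults (a minimal sketch; the values below are illustrative rather than tuned recommendations):
+
+```python
+from vllm_omni.entrypoints.omni import Omni
+
+omni = Omni(
+    model="Qwen/Qwen-Image",
+    cache_backend="cache_dit",
+    cache_config={
+        # DBCache only
+        "Fn_compute_blocks": 8,
+        "Bn_compute_blocks": 0,
+        "max_warmup_steps": 4,
+        "residual_diff_threshold": 0.12,
+        # TaylorSeer and SCM stay at their disabled defaults
+        "enable_taylorseer": False,
+        "scm_steps_mask_policy": None,
+    },
+)
+```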
+
+## Additional Resources
+
+- [Cache-DiT User Guide](https://cache-dit.readthedocs.io/en/latest/user_guide/OVERVIEWS/)
+- [Cache-DiT Benchmark](https://cache-dit.readthedocs.io/en/latest/benchmark/HYBRID_CACHE/)
+- [DBCache Technical Details](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/)
diff --git a/docs/user_guide/diffusion/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..1f82fd6089a2794e3c974fab75c99ac6a61105a1
--- /dev/null
+++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md
@@ -0,0 +1,101 @@
+# CPU Offloading for Diffusion Models
+
+## Overview
+
+vLLM-Omni provides two offloading strategies to reduce GPU memory usage for diffusion models, allowing you to run larger models on GPUs with limited VRAM:
+
+1. **Model-level (Component) Offloading**: Swaps entire model components (DiT transformer, VAE, encoders) between GPU and CPU.
+2. **Layerwise (Blockwise) Offloading**: Keeps only a single transformer block, or a few of them, on GPU at a time, overlapping compute with memory copies.
+
+Both approaches use pinned memory for faster CPU-GPU transfers. For now, the two offloading strategies cannot be used at the same time.
+
+## Model-level CPU Offloading
+
+### Implementation
+
+CPU offload lets the diffusion worker move large model components between GPU and CPU memory on demand. It keeps the DiT transformer resident on GPU only while it is actively running, and swaps it out when encoder modules need the device. This reduces peak VRAM usage so that bigger checkpoints run on smaller GPUs, or multiple requests can share the same GPU.
+
+**Execution Flow**:
+1. Text encoders run on GPU while the DiT transformer is offloaded to CPU.
+2. Before denoising, weights are prefetched back to GPU, using pinned-memory copies for speed.
+3. After the diffusion step, the transformer returns to CPU and the process repeats as needed.
+
+Transfers use pinned host buffers, and the worker coordinates swaps via mutex-style hooks so components never compete for memory.
+
+### Configuration
+You can enable CPU offload in two ways:
+
+1. **Python API**: set `enable_cpu_offload=True`.
+
+```python
+from vllm_omni import Omni
+
+if __name__ == "__main__":
+    m = Omni(model="Qwen/Qwen-Image", enable_cpu_offload=True)
+```
+
+2. **CLI**: pass `--enable-cpu-offload` to the diffusion service entrypoint.
+
+### Limitations
+- Cold start latency increases by over one minute for some models (e.g., Qwen-Image)
+
+## Layerwise (Blockwise) Offloading
+
+### Implementation
+Layerwise offload operates at transformer block granularity, keeping a single transformer block, or a specified number of blocks, on GPU while the others stay in CPU memory.
+
+Unlike full model-level CPU offload, which swaps entire components such as the DiT and encoders, layerwise offloading applies a sliding-window scheme for loading and offloading weights between GPU and CPU: while block `i` computes, block `i+1` gets prefetched asynchronously via pinned memory. In this way, only a few block(s) reside on GPU at any moment during inference, which greatly reduces memory usage.
+
+**Execution Flow**:
+
+1. During model initialization, all components are loaded to CPU first. Then components other than the DiT model(s) in the pipeline, such as the VAE and encoders, are moved to GPU. The weights of the target transformer blocks are collected as contiguous tensors per layer on CPU with pinned memory; non-block modules (embeddings, norms, etc.) in the DiT model are moved to and stay on GPU.
+2. The first block(s) are transferred to GPU during initialization of `LayerwiseOffloader`, before the first denoising step of the very first request.
+3. As each block executes, the next block is prefetched on a separate CUDA stream to overlap compute with memory copies. After execution, the current block is immediately freed from GPU memory.
+4. When the last block completes, the first block is prefetched for the next denoising step.
+
+Example of hook execution for a DiT model with n layers, keeping a single layer on GPU by default:
+
+| Layer (block) idx | forward pre-hook | forward | forward post-hook |
+|-------------------|--------------------------------|------------------|---------------------------|
+| layer-0 | prefetch layer 1 (copy stream) | compute layer 0 | free layer-0 gpu weights |
+| layer-1 | prefetch layer 2 (copy stream) | compute layer 1 | free layer-1 gpu weights |
+| layer-2 | prefetch layer 3 (copy stream) | compute layer 2 | free layer-2 gpu weights |
+| ... | ... | ... | ... |
+| layer-(n-1) | **prefetch layer 0 (copy stream)** | compute layer (n-1) | free layer-(n-1) gpu weights |
+
+### Configuration
+
+1. **Python API**: set `enable_layerwise_offload=True` and optionally `layerwise_num_gpu_layers`.
+
+```python
+from vllm_omni import Omni
+
+if __name__ == "__main__":
+    m = Omni(
+        model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+        enable_layerwise_offload=True,
+        ...
+    )
+```
+
+2. **CLI**: pass `--enable-layerwise-offload` and `--layerwise-num-gpu-layers` to the diffusion service entrypoint.
+
+### Supported Models
+
+| Architecture | Models | Example HF Models | DiT Model Cls | Blocks Attr Name |
+|--------------|--------|-------------------|----------|----------|
+| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | `QwenImageTransformer2DModel` | "transformer_blocks" |
+| `Wan22Pipeline` | Wan2.2 | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `WanTransformer3DModel` | "blocks" |
+
+NOTE: Models must define the `_layerwise_offload_blocks_attr` class attribute so that the layerwise offloader can find the target transformer blocks.
+
+### Limitations
+- Cold start latency increases because of
+   1) all components being loaded to CPU first during initialization, and
+   2) weight consolidation and pinning
+- Performance depends on the CPU <-> GPU interconnect (e.g., PCIe bandwidth).
+- Only a single GPU is supported for now
diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e2c18d64c8f5663e26801e627d534ce9bbc3b3d
--- /dev/null
+++ b/docs/user_guide/diffusion/parallelism_acceleration.md
@@ -0,0 +1,567 @@
+# Parallelism Acceleration Guide
+
+This guide describes how to use parallelism methods in vLLM-Omni to speed up diffusion model inference and reduce the memory requirement on each device.
+
+## Overview
+
+The following parallelism methods are currently supported in vLLM-Omni:
+
+1. DeepSpeed Ulysses Sequence Parallel (DeepSpeed Ulysses-SP) ([arxiv paper](https://arxiv.org/pdf/2309.14509)): Ulysses-SP splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads.
+
+2.
[Ring-Attention](#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded + +3. Classifier-Free-Guidance Parallel (CFG-Parallel): CFG-Parallel runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. + +4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT. + +The following table shows which models are currently supported by parallelism method: + +### ImageGen + +| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | +|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | +| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ | + +!!! note "TP Limitations for Diffusion Models" + We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + + - Good news: The text_encoder typically has minimal impact on overall inference performance. + - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. + + We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771). + + +!!! note "Why Z-Image is TP=2 only" + Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints. + For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**. + +### VideoGen + +| Model | Model Identifier | Ulysses-SP | Ring-SP | Tensor-Parallel | +|-------|------------------|------------|---------|--------------------------| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ❌ | + +### Tensor Parallelism + +Tensor parallelism splits model parameters across GPUs. In vLLM-Omni, tensor parallelism is configured via `DiffusionParallelConfig.tensor_parallel_size`. 
+ +#### Offline Inference + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + parallel_config=DiffusionParallelConfig(tensor_parallel_size=2), +) + +outputs = omni.generate( + "a cat reading a book", + OmniDiffusionSamplingParams( + num_inference_steps=9, + width=512, + height=512, + ), +) +``` + +### Sequence Parallelism + +#### Ulysses-SP + +##### Offline Inference + +An example of offline inference script using [Ulysses-SP](https://arxiv.org/pdf/2309.14509) is shown below: +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.diffusion.data import DiffusionParallelConfig +ulysses_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2) +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) +``` + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. + +##### Online Serving + +You can enable Ulysses-SP in online serving for diffusion models via `--usp`: + +```bash +# Text-to-image (requires >= 2 GPUs) +vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2 +``` + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**2048x2048** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA H800 GPUs. `sdpa` is the attention backends. + +| Configuration | Ulysses degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 112.5s | 1.0x | +| Ulysses-SP | 2 | 65.2s | 1.73x | +| Ulysses-SP | 4 | 39.6s | 2.84x | +| Ulysses-SP | 8 | 30.8s | 3.65x | + +#### Ring-Attention + +Ring-Attention ([arxiv paper](https://arxiv.org/abs/2310.01889)) splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results. Unlike Ulysses-SP which uses all-to-all communication, Ring-Attention keeps the sequence dimension sharded throughout the computation and circulates Key/Value blocks through a ring topology. + +##### Offline Inference + +An example of offline inference script using Ring-Attention is shown below: +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.diffusion.data import DiffusionParallelConfig +ring_degree = 2 + +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ring_degree=2) +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) +``` + +See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. 
+ + +##### Online Serving + +You can enable Ring-Attention in online serving for diffusion models via `--ring`: + +```bash +# Text-to-image (requires >= 2 GPUs) +vllm serve Qwen/Qwen-Image --omni --port 8091 --ring 2 +``` + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends. + +| Configuration | Ring degree |Generation Time | Speedup | +|---------------|----------------|---------|---------| +| **Baseline (diffusers)** | - | 45.2s | 1.0x | +| Ring-Attention | 2 | 29.9s | 1.51x | +| Ring-Attention | 4 | 23.3s | 1.94x | + + +#### Hybrid Ulysses + Ring + +You can combine both Ulysses-SP and Ring-Attention for larger scale parallelism. The total sequence parallel size equals `ulysses_degree × ring_degree`. + +##### Offline Inference + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.diffusion.data import DiffusionParallelConfig + +# Hybrid: 2 Ulysses × 2 Ring = 4 GPUs total +omni = Omni( + model="Qwen/Qwen-Image", + parallel_config=DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) +``` + +##### Online Serving + +```bash +# Text-to-image (requires >= 4 GPUs) +vllm serve Qwen/Qwen-Image --omni --port 8091 --usp 2 --ring 2 +``` + +##### Benchmarks +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + +To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** model generating images (**1024x1024** as long sequence input) with 50 inference steps. The hardware devices are NVIDIA A100 GPUs. `flash_attn` is the attention backends. + +| Configuration | Ulysses degree | Ring degree | Generation Time | Speedup | +|---------------|----------------|-------------|-----------------|---------| +| **Baseline (diffusers)** | - | - | 45.2s | 1.0x | +| Hybrid Ulysses + Ring | 2 | 2 | 24.3s | 1.87x | + + +##### How to parallelize a new model + +NOTE: "Terminology: SP vs CP" + Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in the [diffusers library](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py). + We use "Sequence Parallelism" to align with vLLM-Omni's terminology. + +--- + +###### Non-intrusive `_sp_plan` (Recommended) + +The `_sp_plan` mechanism allows SP without modifying `forward()` logic. 
The framework automatically registers hooks to shard inputs and gather outputs at module boundaries. + +**Requirements for `forward()` function:** + +- Tensor operations that need sharding/gathering must happen at **`nn.Module` boundaries** (not inline Python operations) +- If your `forward()` contains inline tensor operations (e.g., `torch.cat`, `pad_sequence`) that need sharding, **extract them into a submodule** + +**When to create a submodule:** + +```python +# ❌ BAD: Inline operations - hooks cannot intercept +def forward(self, x, cap_feats): + unified = torch.cat([x, cap_feats], dim=1) # Cannot be sharded via _sp_plan + ... + +# ✅ GOOD: Extract into a submodule +class UnifiedPrepare(nn.Module): + def forward(self, x, cap_feats): + return torch.cat([x, cap_feats], dim=1) # Now can be sharded via _sp_plan + +class MyModel(nn.Module): + def __init__(self): + self.unified_prepare = UnifiedPrepare() # Submodule + + def forward(self, x, cap_feats): + unified = self.unified_prepare(x, cap_feats) # Hook can intercept here +``` + +--- + +###### Defining `_sp_plan` + +**Type definitions** (see [diffusers `_modeling_parallel.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py) for reference): + +```python +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, # Corresponds to diffusers' ContextParallelInput + SequenceParallelOutput, # Corresponds to diffusers' ContextParallelOutput +) +``` + +| Parameter | Description | +|-----------|-------------| +| `split_dim` | Dimension to split/gather (usually `1` for sequence) | +| `expected_dims` | Expected tensor rank for validation (optional) | +| `split_output` | `False`: shard **input** parameters; `True`: shard **output** tensors | +| `auto_pad` | Auto-pad if sequence not divisible by world_size (Ulysses only) | + +**Key naming convention:** + +| Key | Meaning | Python equivalent | +|-----|---------|-------------------| +| `""` | Root model | `model` | +| `"blocks.0"` | First element of ModuleList | `model.blocks[0]` | +| `"blocks.*"` | All elements of ModuleList | `for b in model.blocks` | +| `"outputs.main"` | ModuleDict entry | `model.outputs["main"]` | + +**Dictionary key types:** + +| Key type | `split_output` | Description | +|----------|----------------|-------------| +| `"param_name"` (str) | `False` | Shard **input parameter** by name | +| `0`, `1` (int) | `True` | Shard **output tuple** by index | + +**Example** (similar to [diffusers `transformer_wan.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan.py)): + +```python +class MyTransformer(nn.Module): + _sp_plan = { + # Shard rope module OUTPUTS (returns tuple) + "rope": { + 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # cos + 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # sin + }, + # Shard transformer block INPUT parameter + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + # Gather at final projection + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +``` + +--- + +###### Hook flow + +``` +Input → [SequenceParallelSplitHook: pre_forward] → Module.forward() → [post_forward] → ... + ↓ +... → [SequenceParallelGatherHook: post_forward] → Output +``` + +1. **SplitHook** shards tensors before/after the target module +2. **Attention layers** handle Ulysses/Ring communication internally +3. 
**GatherHook** collects sharded outputs + +The framework automatically applies these hooks when `sequence_parallel_size > 1`. + +--- + +###### Method 2: Intrusive modification (For complex cases) + +For models with dynamic sharding logic that cannot be expressed via `_sp_plan`: + +```python +from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather + +def forward(self, hidden_states, ...): + if self.parallel_config.sequence_parallel_size > 1: + hidden_states = sp_shard(hidden_states, dim=1) + # ... computation ... + output = sp_gather(output, dim=1) + return output +``` + +--- + +###### Choosing the right approach + +| Scenario | Approach | +|----------|----------| +| Standard transformer | `_sp_plan` | +| Inline tensor ops need sharding | Extract to submodule + `_sp_plan` | +| Dynamic/conditional sharding | Intrusive modification | + + +### CFG-Parallel + +#### Offline Inference + +CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=2)`, which runs one rank for the positive branch and one rank for the negative branch. + +An example of offline inference using CFG-Parallel (image-to-image) is shown below: + +```python +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig + +image_path = "path_to_image.png" +omni = Omni( + model="Qwen/Qwen-Image-Edit", + parallel_config=DiffusionParallelConfig(cfg_parallel_size=2), +) +input_image = Image.open(image_path).convert("RGB") + +outputs = omni.generate( + { + "prompt": "turn this cat to a dog", + "negative_prompt": "low quality, blurry", + "multi_modal_data": {"image": input_image}, + }, + OmniDiffusionSamplingParams( + true_cfg_scale=4.0, + num_inference_steps=50, + ), +) +``` + +Notes: + +- CFG-Parallel is only effective when a `negative_prompt` is provided AND a guidance scale (or `cfg_scale`) is greater than 1. + +See `examples/offline_inference/image_to_image/image_edit.py` for a complete working example. +```bash +cd examples/offline_inference/image_to_image/ +python image_edit.py \ + --model "Qwen/Qwen-Image-Edit" \ + --image "qwen_image_output.png" \ + --prompt "turn this cat to a dog" \ + --negative_prompt "low quality, blurry" \ + --cfg_scale 4.0 \ + --output "edited_image.png" \ + --cfg_parallel_size 2 +``` + +#### Online Serving + +You can enable CFG-Parallel in online serving for diffusion models via `--cfg-parallel-size`: + +```bash +vllm serve Qwen/Qwen-Image-Edit --omni --port 8091 --cfg-parallel-size 2 +``` + +#### How to parallelize a pipeline + +This section describes how to add CFG-Parallel to a diffusion **pipeline**. We use the Qwen-Image pipeline (`vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py`) as the reference implementation. + +In `QwenImagePipeline`, each diffusion step runs two denoiser forward passes sequentially: + +- positive (prompt-conditioned) +- negative (negative-prompt-conditioned) + +CFG-Parallel assigns these two branches to different ranks in the **CFG group** and synchronizes the results. + +vLLM-omni provides `CFGParallelMixin` base class that encapsulates the CFG parallel logic. By inheriting from this mixin and calling its methods, pipelines can easily implement CFG parallel without writing repetitive code. 
+ +**Key Methods in CFGParallelMixin:** +- `predict_noise_maybe_with_cfg()`: Automatically handles CFG parallel noise prediction +- `scheduler_step_maybe_with_cfg()`: Scheduler step with automatic CFG rank synchronization + +**Example Implementation:** + +```python +class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( + self, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + self.transformer.do_true_cfg = do_true_cfg + + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Prepare kwargs for positive (conditional) prediction + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + } + + # Prepare kwargs for negative (unconditional) prediction + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + # - In CFG parallel mode: rank0 computes positive, rank1 computes negative + # - Automatically gathers results and combines them on rank0 + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=true_cfg_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=cfg_normalize, + ) + + # Step scheduler with automatic CFG synchronization + # - Only rank0 computes the scheduler step + # - Automatically broadcasts updated latents to all ranks + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg + ) + + return latents +``` + +**How it works:** +1. Prepare separate `positive_kwargs` and `negative_kwargs` for conditional and unconditional predictions +2. Call `predict_noise_maybe_with_cfg()` which: + - Detects if CFG parallel is enabled (`get_classifier_free_guidance_world_size() > 1`) + - Distributes computation: rank0 processes positive, rank1 processes negative + - Gathers predictions and combines them using `combine_cfg_noise()` on rank0 + - Returns combined noise prediction (only valid on rank0) +3. Call `scheduler_step_maybe_with_cfg()` which: + - Only rank0 computes the scheduler step + - Broadcasts the updated latents to all ranks for synchronization + +**How to customize** + +Some pipelines may need to customize the following functions in `CFGParallelMixin`: +1. You may need to edit `predict_noise` function for custom behaviors. +```python +def predict_noise(self, *args, **kwargs): + """ + Forward pass through transformer to predict noise. 
+ + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. + """ + return self.transformer(*args, **kwargs)[0] + +``` +2. The default normalization function after combining the noise predictions from both branches is as follows. You may need to customize it. +```python +def cfg_normalize_function(self, noise_pred, comb_pred): + """ + Normalize the combined noise prediction. + + Args: + noise_pred: positive noise prediction + comb_pred: combined noise prediction after CFG + + Returns: + Normalized noise prediction tensor + """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + return noise_pred +``` diff --git a/docs/user_guide/diffusion/teacache.md b/docs/user_guide/diffusion/teacache.md new file mode 100644 index 0000000000000000000000000000000000000000..40dafeb88adf360fe3b63d0161b315ea86537df4 --- /dev/null +++ b/docs/user_guide/diffusion/teacache.md @@ -0,0 +1,145 @@ +# TeaCache Configuration Guide + +TeaCache speeds up diffusion model inference by caching transformer computations when consecutive timesteps are similar. This typically provides **1.5x-2.0x speedup** with minimal quality loss. + +## Quick Start + +Enable TeaCache by setting `cache_backend` to `"tea_cache"`: + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# Simple configuration - model_type is automatically extracted from pipeline.__class__.__name__ +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="tea_cache", + cache_config={ + "rel_l1_thresh": 0.2 # Optional, defaults to 0.2 + } +) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) +``` + +### Using Environment Variable + +You can also enable TeaCache via environment variable: + +```bash +export DIFFUSION_CACHE_BACKEND=tea_cache +``` + +Then initialize without explicitly setting `cache_backend`: + +```python +from vllm_omni import Omni + +omni = Omni( + model="Qwen/Qwen-Image", + cache_config={"rel_l1_thresh": 0.2} # Optional +) +``` + +## Online Serving (OpenAI-Compatible) + +Enable TeaCache for online serving by passing `--cache-backend tea_cache` when starting the server: + +```bash +vllm serve Qwen/Qwen-Image --omni --port 8091 \ + --cache-backend tea_cache \ + --cache-config '{"rel_l1_thresh": 0.2}' +``` + +## Configuration Parameters + +### `rel_l1_thresh` (float, default: `0.2`) + +Controls the balance between speed and quality. Lower values prioritize quality, higher values prioritize speed. 
+ +**Recommended values:** + +- `0.2` - **~1.5x speedup** with minimal quality loss (recommended) +- `0.4` - **~1.8x speedup** with slight quality loss +- `0.6` - **~2.0x speedup** with noticeable quality loss +- `0.8` - **~2.25x speedup** with significant quality loss + +## Examples + +### Python API + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2} +) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) +``` + +## Performance Tuning + +Start with the default `rel_l1_thresh=0.2` and adjust based on your needs: + +- **Maximum quality**: Use `0.1-0.2` +- **Balanced**: Use `0.2-0.4` (recommended) +- **Maximum speed**: Use `0.6-0.8` (may reduce quality) + +## Troubleshooting + +### Quality Degradation + +If you notice quality issues, lower the threshold: + +```python +cache_config={"rel_l1_thresh": 0.1} # More conservative caching +``` + +## Supported Models + +### ImageGen + +<style> +th { + white-space: nowrap; + min-width: 0 !important; +} +</style> + +| Architecture | Models | Example HF Models | +|--------------|--------|-------------------| +| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | +| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | +| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` | +| `QwenImageLayeredPipeline` | Qwen-Image-Layered | `Qwen/Qwen-Image-Layered` | +| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | + +### VideoGen + +No VideoGen models are supported by TeaCache yet. + +### Coming Soon + +<style> +th { + white-space: nowrap; + min-width: 0 !important; +} +</style> + +| Architecture | Models | Example HF Models | +|--------------|--------|-------------------| +| `FluxPipeline` | Flux | - | +| `CogVideoXPipeline` | CogVideoX | - | diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md new file mode 100644 index 0000000000000000000000000000000000000000..d081243782da96803346c63e46cae6289475b11d --- /dev/null +++ b/docs/user_guide/diffusion_acceleration.md @@ -0,0 +1,236 @@ +# Diffusion Acceleration Overview + +vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. + +## Supported Acceleration Methods + +vLLM-Omni currently supports two main cache acceleration backends: + +1. **[TeaCache](diffusion/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar +2. **[Cache-DiT](diffusion/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: + - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences + - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference + - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking + +Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality. + +vLLM-Omni also supports parallelism methods for diffusion models, including: + +1. 
[Ulysses-SP](diffusion/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. + +2. [Ring-Attention](diffusion/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded. + +3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. + +## Quick Comparison + +### Cache Methods + +| Method | Configuration | Description | Best For | +|--------|--------------|-------------|----------| +| **TeaCache** | `cache_backend="tea_cache"` | Simple, adaptive caching with minimal configuration | Quick setup, balanced speed/quality | +| **Cache-DiT** | `cache_backend="cache_dit"` | Advanced caching with multiple techniques (DBCache, TaylorSeer, SCM) | Maximum acceleration, fine-grained control | + +## Supported Models + +The following table shows which models are currently supported by each acceleration method: + +### ImageGen + +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | +|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | + +### VideoGen + +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel | +|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | + + +## Performance Benchmarks + +The following benchmarks were measured on **Qwen/Qwen-Image** and **Qwen/Qwen-Image-Edit** models generating 1024x1024 images with 50 inference steps: + +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. The configurations shown use default or common parameter settings and have not been exhaustively optimized for maximum performance. Actual performance may vary based on: + + - Specific model and use case + - Hardware configuration + - Careful parameter tuning + - Different inference settings (e.g., number of steps, image resolution) + + For optimal performance in your specific scenario, we recommend experimenting with different parameter configurations as described in the detailed guides below. 
+
+| Model | Cache Backend | Cache Config | Generation Time | Speedup | Notes |
+|-------|---------------|--------------|----------------|---------|-------|
+| **Qwen/Qwen-Image** | None | None | 20.0s | 1.0x | Baseline (diffusers) |
+| **Qwen/Qwen-Image** | TeaCache | `rel_l1_thresh=0.2` | 10.47s | **1.91x** | Recommended default setting |
+| **Qwen/Qwen-Image** | Cache-DiT | DBCache + TaylorSeer (Fn=1, Bn=0, W=8, TaylorSeer order=1) | 10.8s | **1.85x** | - |
+| **Qwen/Qwen-Image** | Cache-DiT | DBCache + TaylorSeer + SCM (Fn=8, Bn=0, W=4, TaylorSeer order=1, SCM fast) | 14.0s | **1.43x** | - |
+| **Qwen/Qwen-Image-Edit** | None | No acceleration | 51.5s | 1.0x | Baseline (diffusers) |
+| **Qwen/Qwen-Image-Edit** | Cache-DiT | Default (Fn=1, Bn=0, W=4, TaylorSeer disabled, SCM disabled) | 21.6s | **2.38x** | - |
+
+To measure the parallelism methods, we benchmark the **Qwen/Qwen-Image** model generating **2048x2048** images (a long-sequence input) with 50 inference steps on NVIDIA H800 GPUs, using the `sdpa` attention backend.
+
+| Configuration | Ulysses degree | Generation Time | Speedup |
+|---------------|----------------|-----------------|---------|
+| **Baseline (diffusers)** | - | 112.5s | 1.0x |
+| Ulysses-SP | 2 | 65.2s | 1.73x |
+| Ulysses-SP | 4 | 39.6s | 2.84x |
+| Ulysses-SP | 8 | 30.8s | 3.65x |
+
+## Quick Start
+
+### Using TeaCache
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+    model="Qwen/Qwen-Image",
+    cache_backend="tea_cache",
+    cache_config={"rel_l1_thresh": 0.2}  # Optional, defaults to 0.2
+)
+
+outputs = omni.generate(
+    "A cat sitting on a windowsill",
+    OmniDiffusionSamplingParams(
+        num_inference_steps=50,
+    ),
+)
+```
+
+### Using Cache-DiT
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+    model="Qwen/Qwen-Image",
+    cache_backend="cache_dit",
+    cache_config={
+        "Fn_compute_blocks": 1,
+        "Bn_compute_blocks": 0,
+        "max_warmup_steps": 8,
+        "enable_taylorseer": True,
+        "taylorseer_order": 1,
+    }
+)
+
+outputs = omni.generate(
+    "A cat sitting on a windowsill",
+    OmniDiffusionSamplingParams(
+        num_inference_steps=50,
+    ),
+)
+```
+
+### Using Ulysses-SP
+
+Run text-to-image:
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+ulysses_degree = 2
+
+omni = Omni(
+    model="Qwen/Qwen-Image",
+    parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree)
+)
+
+outputs = omni.generate(
+    "A cat sitting on a windowsill",
+    OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048),
+)
+```
+
+Run image-to-image:
+```python
+from PIL import Image
+
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+ulysses_degree = 2
+# Source image to edit, loaded as an RGB PIL image.
+input_image = Image.open("input.png").convert("RGB")
+
+omni = Omni(
+    model="Qwen/Qwen-Image-Edit",
+    parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree)
+)
+
+outputs = omni.generate(
+    {
+        "prompt": "turn this cat into a dog",
+        "multi_modal_data": {"image": input_image}
+    },
+    OmniDiffusionSamplingParams(num_inference_steps=50),
+)
+```
+
+### Using Ring-Attention
+
+Run text-to-image:
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+ring_degree = 2
+
+omni = Omni(
+    model="Qwen/Qwen-Image",
+    parallel_config=DiffusionParallelConfig(ring_degree=ring_degree)
+)
+
+outputs = omni.generate(
+    "A cat sitting on a windowsill",
+    OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048),
+)
+```
+
+### Using CFG-Parallel
+
+CFG-Parallel splits the CFG positive/negative branches across GPUs. Use it when you set a non-trivial `true_cfg_scale`.
+
+Run image-to-image:
+```python
+from PIL import Image
+
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+cfg_parallel_size = 2
+# Source image to edit, loaded as an RGB PIL image.
+input_image = Image.open("input.png").convert("RGB")
+
+omni = Omni(
+    model="Qwen/Qwen-Image-Edit",
+    parallel_config=DiffusionParallelConfig(cfg_parallel_size=cfg_parallel_size)
+)
+
+outputs = omni.generate(
+    {
+        "prompt": "turn this cat into a dog",
+        "multi_modal_data": {"image": input_image}
+    },
+    OmniDiffusionSamplingParams(num_inference_steps=50, true_cfg_scale=4.0),
+)
+```
+
+## Documentation
+
+For detailed information on each acceleration method:
+
+- **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices
+- **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters
+- **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - How to configure sequence parallelism (Ulysses-SP and Ring-Attention)
+- **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - How to configure CFG-Parallel to run the positive/negative CFG branches across ranks
diff --git a/docs/user_guide/examples/offline_inference/bagel.md b/docs/user_guide/examples/offline_inference/bagel.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2ee5d2418a6ee157edf42fc623d7946cd5ef8f3
--- /dev/null
+++ b/docs/user_guide/examples/offline_inference/bagel.md
@@ -0,0 +1,179 @@
+# BAGEL-7B-MoT
+
+Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/bagel>.
+
+## Set up
+
+Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
+
+## Run examples
+
+**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices.
+
+Get into the bagel folder
+
+```bash
+cd examples/offline_inference/bagel
+```
+
+### Modality Control
+
+BAGEL-7B-MoT supports multiple modality modes. 
You can control the mode using the `--modality` argument: + +#### Text to Image (text2img) + +- **Pipeline**: Text → Thinker → DiT → VAE Decode → Image +- **Stages Used**: Stage 0 (Thinker) + Stage 1 (DiT) +- **KV Transfer**: Thinker sends KV cache to DiT for conditioned generation + +Generate images from text prompts: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2img \ + --prompts "A cute cat" +``` + +#### Image to Image (img2img) + +- **Pipeline**: Image → VAE Encode → DiT → VAE Decode → New Image +- **Stages Used**: Stage 1 (DiT) only +- **Special**: Bypasses the Thinker stage, direct image-to-image transformation + +Transform images based on text prompts: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality img2img \ + --image-path /path/to/image.jpg \ + --prompts "Let the woman wear a blue dress" +``` + +#### Image to Text (img2text) + +- **Pipeline**: Image → ViT + VAE Encode → Thinker → Text Output +- **Stages Used**: Stage 0 (Thinker) only +- **Special**: Uses both VAE latent encoding AND ViT semantic encoding for comprehensive image understanding + +Generate text descriptions from images: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality img2text \ + --image-path /path/to/image.jpg \ + --prompts "Describe this image in detail" +``` + +#### Text to Text (text2text) + +- **Pipeline**: Text → Thinker → Text Output +- **Stages Used**: Stage 0 (Thinker) only +- **Special**: No visual components involved, operates as pure language model + +Pure text generation: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2text \ + --prompts "What is the capital of France?" + +# You can load prompts from a text file (one prompt per line): +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2text \ + --txt-prompts /path/to/prompts.txt +``` + +### Inference Steps + +Control the number of inference steps for image generation: + +```bash +# You can adjust steps to 100 to improve image quality +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2img \ + --steps 50 \ + --prompts "A cute cat" +``` + +### Key arguments + +BAGEL-7B-MoT supports **multiple modality modes** for different use cases. + +The default yaml configuration deploys Thinker and DiT on the same GPU. 
You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml) + +#### 📌 Command Line Arguments (end2end.py) + +| Argument | Type | Default | Description | +| :--------------------- | :----- | :---------------------------- | :----------------------------------------------------------- | +| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or name | +| `--modality` | choice | `text2img` | Modality mode: `text2img`, `img2img`, `img2text`, `text2text` | +| `--prompts` | list | `None` | Input text prompts directly | +| `--txt-prompts` | string | `None` | Path to txt file with one prompt per line | +| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) | +| `--steps` | int | `50` | Number of inference steps | +| `--stage-configs-path` | string | `None` | Custom stage config file path | +| `--worker-backend` | choice | `process` | Worker backend: `process` or `ray` | +| `--ray-address` | string | `None` | Ray cluster address | +| `--enable-stats` | flag | `False` | Enable statistics logging | +| `--init-sleep-seconds` | int | `20` | Initialization sleep time | +| `--batch-timeout` | int | `5` | Batch timeout | +| `--init-timeout` | int | `300` | Initialization timeout | + +------ + +#### ⚙️ Stage Configuration Parameters (bagel.yaml) + + **Stage 0 - Thinker (LLM Stage)** + +| Parameter | Value | Description | +| :------------------------------- | :------------------------------ | :----------------------- | +| `stage_type` | `llm` | Stage type | +| `devices` | `"0"` | GPU device ID | +| `max_batch_size` | `1` | Maximum batch size | +| `model_stage` | `thinker` | Model stage identifier | +| `model_arch` | `BagelForConditionalGeneration` | Model architecture | +| `gpu_memory_utilization` | `0.4` | GPU memory utilization | +| `tensor_parallel_size` | `1` | Tensor parallel size | +| `max_num_batched_tokens` | `32768` | Maximum batched tokens | +| `omni_kv_config.need_send_cache` | `true` | Whether to send KV cache | + +------ + +**Stage 1 - DiT (Diffusion Stage)** + +| Parameter | Value | Description | +| :------------------------------- | :---------- | :-------------------------- | +| `stage_type` | `diffusion` | Stage type | +| `devices` | `"0"` | GPU device ID | +| `max_batch_size` | `1` | Maximum batch size | +| `model_stage` | `dit` | Model stage identifier | +| `gpu_memory_utilization` | `0.4` | GPU memory utilization | +| `omni_kv_config.need_recv_cache` | `true` | Whether to receive KV cache | +| `engine_input_source` | `[0]` | Input source from Stage 0 | + +------ + +#### 🔗 Runtime Configuration + +| Parameter | Value | Description | +| :-------------------- | :------ | :------------------------------- | +| `window_size` | `-1` | Window size (-1 means unlimited) | +| `max_inflight` | `1` | Maximum inflight requests | +| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) | + +## FAQ + +- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. + +```bash +sudo apt update +sudo apt install ffmpeg +``` + +- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. 
+ +| Stage | VRAM | +| :------------------ | :--------------------------- | +| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** | +| Stage-1 (DiT) | **26.50 GiB** | +| Total | **~42 GiB + KV Cache** | diff --git a/docs/user_guide/examples/offline_inference/image_to_image.md b/docs/user_guide/examples/offline_inference/image_to_image.md new file mode 100644 index 0000000000000000000000000000000000000000..c970106d2b30eccd39b55421a478b2ae41eae559 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/image_to_image.md @@ -0,0 +1,64 @@ +# Image-To-Image + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_image>. + + +This example edits an input image with `Qwen/Qwen-Image-Edit` using the `image_edit.py` CLI. + +## Local CLI Usage + +### Single Image Editing + +Download the example image: + +```bash +wget https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/qwen-bear.png +``` + +Then run: + +```bash +python image_edit.py \ + --image qwen-bear.png \ + --prompt "Let this mascot dance under the moon, surrounded by floating stars and poetic bubbles such as 'Be Kind'" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 +``` + +### Multiple Image Editing (Qwen-Image-Edit-2509) + +For multiple image inputs, use `Qwen/Qwen-Image-Edit-2509` or `Qwen/Qwen-Image-Edit-2511`: + +```bash +python image_edit.py \ + --model Qwen/Qwen-Image-Edit-2509 \ + --image img1.png img2.png \ + --prompt "Combine these images into a single scene" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --guidance_scale 1.0 +``` + +Key arguments: + +- `--model`: model name or path. Use `Qwen/Qwen-Image-Edit-2509` or later for multiple image support. +- `--image`: path(s) to the source image(s) (PNG/JPG, converted to RGB). Can specify multiple images. +- `--prompt` / `--negative_prompt`: text description (string). +- `--cfg_scale`: true classifier-free guidance scale (default: 4.0). Classifier-free guidance is enabled by setting cfg_scale > 1 and providing a negative_prompt. Higher guidance scale encourages images closely linked to the text prompt, usually at the expense of lower image quality. +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. +- `--guidance_scale`: guidance scale for guidance-distilled models (default: 1.0, disabled). Unlike classifier-free guidance (--cfg_scale), guidance-distilled models take the guidance scale directly as an input parameter. Enabled when guidance_scale > 1. Ignored when not using guidance-distilled models. +- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). +- `--output`: path to save the generated PNG. + +## Example materials + +??? abstract "image_edit.py" + ``````py + --8<-- "examples/offline_inference/image_to_image/image_edit.py" + `````` +??? 
abstract "run_qwen_image_edit_2511.sh" + ``````sh + --8<-- "examples/offline_inference/image_to_image/run_qwen_image_edit_2511.sh" + `````` diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md new file mode 100644 index 0000000000000000000000000000000000000000..d65839dd75840d91f5ab63e8fab5cafec66395bd --- /dev/null +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -0,0 +1,67 @@ +# Image-To-Video + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video>. + + +This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models with vLLM-Omni's offline inference API. + +## Local CLI Usage + +### Wan2.2-I2V-A14B-Diffusers (MoE) +```bash +python image_to_video.py \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --image input.png \ + --prompt "A cat playing with yarn, smooth motion" \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 832 \ + --num_frames 48 \ + --guidance_scale 5.0 \ + --guidance_scale_high 6.0 \ + --num_inference_steps 40 \ + --boundary_ratio 0.875 \ + --flow_shift 12.0 \ + --fps 16 \ + --output i2v_output.mp4 +``` + +### Wan2.2-TI2V-5B-Diffusers (Unified) +```bash +python image_to_video.py \ + --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --image input.png \ + --prompt "A cat playing with yarn, smooth motion" \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 832 \ + --num_frames 48 \ + --guidance_scale 4.0 \ + --num_inference_steps 40 \ + --flow_shift 12.0 \ + --fps 16 \ + --output i2v_output.mp4 +``` + +Key arguments: + +- `--model`: Model ID (I2V-A14B for MoE, TI2V-5B for unified T2V+I2V). +- `--image`: Path to input image (required). +- `--prompt`: Text description of desired motion/animation. +- `--height/--width`: Output resolution (auto-calculated from image if not set). Dimensions should be multiples of 16. +- `--num_frames`: Number of frames (default 81). +- `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high-noise stages for MoE). +- `--negative_prompt`: Optional list of artifacts to suppress. +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. +- `--boundary_ratio`: Boundary split ratio for two-stage MoE models. +- `--flow_shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--num_inference_steps`: Number of denoising steps (default 50). +- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--output`: Path to save the generated video. + +## Example materials + +??? abstract "image_to_video.py" + ``````py + --8<-- "examples/offline_inference/image_to_video/image_to_video.py" + `````` diff --git a/docs/user_guide/examples/offline_inference/lora_inference.md b/docs/user_guide/examples/offline_inference/lora_inference.md new file mode 100644 index 0000000000000000000000000000000000000000..dde42655e44840874fe4455cb71f93ab3f9ceef5 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/lora_inference.md @@ -0,0 +1,107 @@ +# LoRA-Inference + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/lora_inference>. + +This contains examples for using LoRA (Low-Rank Adaptation) adapters with vLLM-omni diffusion models for offline inference. 
+The example uses the `stabilityai/stable-diffusion-3.5-medium` as the default model, but you can replace it with other models in vLLM-omni. + +## Overview + +Similar to vLLM, vLLM-omni uses a unified LoRA handling mechanism: + +- **Pre-loaded LoRA**: Loaded at initialization via `--lora-path` (pre-loaded into cache) +- **Per-request LoRA**: Loaded on-demand. In the example, the LoRA is loaded via `--lora-request-path` in each request + +Both approaches use the same underlying mechanism - all LoRA adapters are handled uniformly through `set_active_adapter()`. If no LoRA request is provided in a request, all adapters are deactivated. + +## Usage + +### Pre-loaded LoRA (via --lora-path) + +Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_preloaded.png +``` + +**Note**: When using `--lora-path`, the adapter is loaded at init time with a stable ID derived from the adapter path. This example activates it automatically for the request. + +### Per-request LoRA (via --lora-request-path) + +Load a LoRA adapter on-demand for each request: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-request-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_per_request.png +``` + +### No LoRA + +If no LoRA request is provided, we will use the base model without any LoRA adapters: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_no_lora.png +``` + +## Parameters + +### LoRA Parameters + +- `--lora-path`: Path to LoRA adapter folder to pre-load at initialization (loads into cache with a stable ID derived from the path) +- `--lora-request-path`: Path to LoRA adapter folder for per-request loading +- `--lora-request-id`: Integer ID for the LoRA adapter (optional). If not provided and `--lora-request-path` is set, will derive a stable ID from the path. +- `--lora-scale`: Scale factor for LoRA weights (default: 1.0). Higher values increase the influence of the LoRA adapter. + +### Standard Parameters + +- `--prompt`: Text prompt for image generation (required) +- `--seed`: Random seed for reproducibility (default: 42) +- `--height`: Image height in pixels (default: 1024) +- `--width`: Image width in pixels (default: 1024) +- `--num_inference_steps`: Number of denoising steps (default: 50) +- `--output`: Output file path (default: `lora_output.png`) + +## How LoRA Works + +All LoRA adapters are handled uniformly: + +1. **Initialization**: If `--lora-path` is provided, the adapter is loaded into cache with a stable ID derived from the adapter path +2. **Per-request**: If `--lora-request-path` is provided, the adapter is loaded/activated for that request +3. **No LoRA**: If no LoRA request is provided (`req.lora_request` is None), all adapters are deactivated + +The system uses LRU cache management - adapters are cached and evicted when the cache is full (unless pinned). + +## LoRA Adapter Format + +LoRA adapters must be in PEFT (Parameter-Efficient Fine-Tuning) format. 
A typical LoRA adapter directory structure: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` + +## Example materials + +??? abstract "lora_inference.py" + ``````py + --8<-- "examples/offline_inference/lora_inference/lora_inference.py" + `````` diff --git a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md new file mode 100644 index 0000000000000000000000000000000000000000..07a56cf9a0673803b3c49ca2df9bb25bc3814afb --- /dev/null +++ b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md @@ -0,0 +1,92 @@ +# Qwen2.5-Omni + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/qwen2_5_omni>. + + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Multiple Prompts +Get into the example folder +```bash +cd examples/offline_inference/qwen2_5_omni +``` +Then run the command below. Note: for processing large volume data, it uses py_generator mode, which will return a python generator from Omni class. +```bash +bash run_multiple_prompts.sh +``` + +### Single Prompt +Get into the example folder +```bash +cd examples/offline_inference/qwen2_5_omni +``` +Then run the command below. +```bash +bash run_single_prompt.sh +``` + +### Modality control +If you want to control output modalities, e.g. only output text, you can run the command below: +```bash +python end2end.py --output-wav output_audio \ + --query-type mixed_modalities \ + --modalities text +``` + +#### Using Local Media Files +The `end2end.py` script supports local media files (audio, video, image) via CLI arguments: + +```bash +# Use single local media files +python end2end.py --query-type use_image --image-path /path/to/image.jpg +python end2end.py --query-type use_video --video-path /path/to/video.mp4 +python end2end.py --query-type use_audio --audio-path /path/to/audio.wav + +# Combine multiple local media files +python end2end.py --query-type mixed_modalities \ + --video-path /path/to/video.mp4 \ + --image-path /path/to/image.jpg \ + --audio-path /path/to/audio.wav + +# Use audio from video file +python end2end.py --query-type use_audio_in_video --video-path /path/to/video.mp4 + +``` + +If media file paths are not provided, the script will use default assets. Supported query types: +- `use_image`: Image input only +- `use_video`: Video input only +- `use_audio`: Audio input only +- `mixed_modalities`: Audio + image + video +- `use_audio_in_video`: Extract audio from video +- `text`: Text-only query + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` + +## Example materials + +??? abstract "end2end.py" + ``````py + --8<-- "examples/offline_inference/qwen2_5_omni/end2end.py" + `````` +??? abstract "extract_prompts.py" + ``````py + --8<-- "examples/offline_inference/qwen2_5_omni/extract_prompts.py" + `````` +??? abstract "run_multiple_prompts.sh" + ``````sh + --8<-- "examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh" + `````` +??? 
abstract "run_single_prompt.sh" + ``````sh + --8<-- "examples/offline_inference/qwen2_5_omni/run_single_prompt.sh" + `````` diff --git a/docs/user_guide/examples/offline_inference/qwen3_omni.md b/docs/user_guide/examples/offline_inference/qwen3_omni.md new file mode 100644 index 0000000000000000000000000000000000000000..2e28f7dea92071f917f66309ba758371abaf5d59 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/qwen3_omni.md @@ -0,0 +1,99 @@ +# Qwen3-Omni + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/qwen3_omni>. + + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Multiple Prompts +Get into the example folder +```bash +cd examples/offline_inference/qwen3_omni +``` +Then run the command below. Note: for processing large volume data, it uses py_generator mode, which will return a python generator from Omni class. +```bash +bash run_multiple_prompts.sh +``` +### Single Prompt +Get into the example folder +```bash +cd examples/offline_inference/qwen3_omni +``` +Then run the command below. +```bash +bash run_single_prompt.sh +``` +If you have not enough memory, you can set thinker with tensor parallel. Just run the command below. +```bash +bash run_single_prompt_tp.sh +``` + +### Modality control +If you want to control output modalities, e.g. only output text, you can run the command below: +```bash +python end2end.py --output-wav output_audio \ + --query-type use_audio \ + --modalities text +``` + +#### Using Local Media Files +The `end2end.py` script supports local media files (audio, video, image) via command-line arguments: + +```bash +# Use local video file +python end2end.py --query-type use_video --video-path /path/to/video.mp4 + +# Use local image file +python end2end.py --query-type use_image --image-path /path/to/image.jpg + +# Use local audio file +python end2end.py --query-type use_audio --audio-path /path/to/audio.wav + +# Combine multiple local media files +python end2end.py --query-type mixed_modalities \ + --video-path /path/to/video.mp4 \ + --image-path /path/to/image.jpg \ + --audio-path /path/to/audio.wav +``` + +If media file paths are not provided, the script will use default assets. Supported query types: +- `use_video`: Video input +- `use_image`: Image input +- `use_audio`: Audio input +- `text`: Text-only query +- `multi_audios`: Multiple audio inputs +- `mixed_modalities`: Combination of video, image, and audio inputs + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` + +## Example materials + +??? abstract "end2end.py" + ``````py + --8<-- "examples/offline_inference/qwen3_omni/end2end.py" + `````` +??? abstract "run_multiple_prompts.sh" + ``````sh + --8<-- "examples/offline_inference/qwen3_omni/run_multiple_prompts.sh" + `````` +??? abstract "run_single_prompt.sh" + ``````sh + --8<-- "examples/offline_inference/qwen3_omni/run_single_prompt.sh" + `````` +??? abstract "run_single_prompt_tp.sh" + ``````sh + --8<-- "examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh" + `````` +??? 
abstract "text_prompts_10.txt" + ``````txt + --8<-- "examples/offline_inference/qwen3_omni/text_prompts_10.txt" + `````` diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md new file mode 100644 index 0000000000000000000000000000000000000000..81bcb1c1133211ac5fb75811e917c42f2c1ca30c --- /dev/null +++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md @@ -0,0 +1,94 @@ +# Qwen3-TTS Offline Inference + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/qwen3_tts>. + + +This directory contains an offline demo for running Qwen3 TTS models with vLLM Omni. It builds task-specific inputs and generates WAV files locally. + +## Model Overview + +Qwen3 TTS provides multiple task variants for speech generation: + +- **CustomVoice**: Generate speech with a known speaker identity (speaker ID) and optional instruction. +- **VoiceDesign**: Generate speech from text plus a descriptive instruction that designs a new voice. +- **Base**: Voice cloning using a reference audio + reference transcript, with optional mode selection. + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Quick Start + +Run a single sample for a task: + +``` +python end2end.py --query-type CustomVoice +``` + +Generated audio files are saved to `output_audio/` by default. + +## Task Usage + +### CustomVoice + +Single sample: + +``` +python end2end.py --query-type CustomVoice +``` + +Batch sample (multiple prompts in one run): + +``` +python end2end.py --query-type CustomVoice --use-batch-sample +``` + +### VoiceDesign + +Single sample: + +``` +python end2end.py --query-type VoiceDesign +``` + +Batch sample: + +``` +python end2end.py --query-type VoiceDesign --use-batch-sample +``` + +### Base (Voice Clone) + +Single sample: + +``` +python end2end.py --query-type Base +``` + +Batch sample: + +``` +python end2end.py --query-type Base --use-batch-sample +``` + +Mode selection for Base: + +- `--mode-tag icl` (default): standard mode +- `--mode-tag xvec_only`: enable `x_vector_only_mode` in the request + +Examples: + +``` +python end2end.py --query-type Base --mode-tag icl +``` + +## Notes + +- The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs. +- Use `--output-dir` (preferred) or `--output-wav` to change the output folder. + +## Example materials + +??? abstract "end2end.py" + ``````py + --8<-- "examples/offline_inference/qwen3_tts/end2end.py" + `````` diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md new file mode 100644 index 0000000000000000000000000000000000000000..486b9f63b0613554540944202a38b3d475a60190 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -0,0 +1,125 @@ +# Text-To-Image + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/text_to_image>. + + +This folder provides several entrypoints for experimenting with `Qwen/Qwen-Image` `Qwen/Qwen-Image-2512` `Tongyi-MAI/Z-Image-Turbo` using vLLM-Omni: + +- `text_to_image.py`: command-line script for single image generation with advanced options. +- `web_demo.py`: lightweight Gradio UI for interactive prompt/seed/CFG exploration. 
+ +Note that when you pass in multiple independent prompts, they will be processed sequentially. Batching requests is currently not supported. + +## Basic Usage + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + prompt = "a cup of coffee on the table" + outputs = omni.generate(prompt) + images = outputs[0].request_output[0].images + images[0].save("coffee.png") +``` + +Or put more than one prompt in a request. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + prompts = [ + "a cup of coffee on a table", + "a toy dinosaur on a sandy beach", + "a fox waking up in bed and yawning", + ] + outputs = omni.generate(prompts) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + +!!! info + + However, it is not currently recommended to do so + because not all models support batch inference, + and batch requesting mostly does not provide significant performance improvement (despite the impression that it does). + This feature is primarily for the sake of interface compatibility with vLLM and to allow for future improvements. + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../../../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + +Apart from string prompt, vLLM-Omni also supports dictionary prompts in the same style as vLLM. +This is useful for models that support negative prompts. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + outputs = omni.generate([ + { + "prompt": "a cup of coffee on a table", + "negative_prompt": "low resolution" + }, + { + "prompt": "a toy dinosaur on a sandy beach", + "negative_prompt": "cinematic, realistic" + } + ]) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + +## Local CLI Usage + +```bash +python text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --prompt "a cup of coffee on the table" \ + --seed 42 \ + --cfg_scale 4.0 \ + --num_images_per_prompt 1 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output outputs/coffee.png +``` + +Key arguments: + +- `--prompt`: text description (string). +- `--seed`: integer seed for deterministic sampling. +- `--cfg_scale`: true CFG scale (model-specific guidance strength). +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. +- `--num_images_per_prompt`: number of images to generate per prompt (saves as `output`, `output_1`, ...). +- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). +- `--height/--width`: output resolution (defaults 1024x1024). +- `--output`: path to save the generated PNG. + +> ℹ️ Qwen-Image currently publishes best-effort presets at `1328x1328`, `1664x928`, `928x1664`, `1472x1140`, `1140x1472`, `1584x1056`, and `1056x1584`. Adjust `--height/--width` accordingly for the most reliable outcomes. 
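+
+The same settings can also be applied through the Python API rather than the CLI flags above. A minimal sketch, assuming the `OmniDiffusionSamplingParams` fields shown in the diffusion acceleration guide (`width`, `height`, `num_inference_steps`, `true_cfg_scale`); treating `true_cfg_scale` as the Python-side counterpart of `--cfg_scale` is an assumption:
+
+```python
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+if __name__ == "__main__":
+    omni = Omni(model="Qwen/Qwen-Image")
+    outputs = omni.generate(
+        "a cup of coffee on the table",
+        OmniDiffusionSamplingParams(
+            num_inference_steps=50,
+            width=1328,   # one of the published Qwen-Image presets listed above
+            height=1328,
+            true_cfg_scale=4.0,  # assumed equivalent of the --cfg_scale CLI flag
+        ),
+    )
+    outputs[0].request_output[0].images[0].save("coffee_1328.png")
+```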
+ +## Web UI Demo + +Launch the gradio demo: + +```bash +python gradio_demo.py --port 7862 +``` + +Then open `http://localhost:7862/` on your local browser to interact with the web UI. + +## Example materials + +??? abstract "gradio_demo.py" + ``````py + --8<-- "examples/offline_inference/text_to_image/gradio_demo.py" + `````` +??? abstract "text_to_image.py" + ``````py + --8<-- "examples/offline_inference/text_to_image/text_to_image.py" + `````` diff --git a/docs/user_guide/examples/offline_inference/text_to_video.md b/docs/user_guide/examples/offline_inference/text_to_video.md new file mode 100644 index 0000000000000000000000000000000000000000..db0860b38e2188c0e178441d4af5f0dfd87c6ffc --- /dev/null +++ b/docs/user_guide/examples/offline_inference/text_to_video.md @@ -0,0 +1,41 @@ +# Text-To-Video + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/text_to_video>. + + +The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts. + +## Local CLI Usage + +```bash +python text_to_video.py \ + --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 640 \ + --num_frames 32 \ + --guidance_scale 4.0 \ + --guidance_scale_high 3.0 \ + --num_inference_steps 40 \ + --fps 16 \ + --output t2v_out.mp4 +``` + +Key arguments: + +- `--prompt`: text description (string). +- `--height/--width`: output resolution (defaults 720x1280). Dimensions should align with Wan VAE downsampling (multiples of 8). +- `--num_frames`: Number of frames (Wan default is 81). +- `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high).. +- `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string). +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. +- `--boundary_ratio`: Boundary split ratio for low/high DiT. +- `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--output`: path to save the generated video. + +## Example materials + +??? abstract "text_to_video.py" + ``````py + --8<-- "examples/offline_inference/text_to_video/text_to_video.py" + `````` diff --git a/docs/user_guide/examples/online_serving/bagel.md b/docs/user_guide/examples/online_serving/bagel.md new file mode 100644 index 0000000000000000000000000000000000000000..107e42449a2ed43a7cba14999a3ec42c3f74f63a --- /dev/null +++ b/docs/user_guide/examples/online_serving/bagel.md @@ -0,0 +1,232 @@ +# BAGEL-7B-MoT + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/bagel>. + +## 🛠️ Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (BAGEL-7B-MoT) + +**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices. 
+ +### Launch the Server + +```bash +# Use default configuration +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 +``` + +Or use the convenience script: + +```bash +cd /workspace/vllm-omni/examples/online_serving/bagel +bash run_server.sh +``` + +If you have a custom stage configs file, launch the server with the command below: + +```bash +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +Get into the bagel folder: + +```bash +cd examples/online_serving/bagel +``` + +Send request via Python + +```bash +python openai_chat_client.py --prompt "A cute cat" --modality text2img +``` + +The Python client supports the following command-line arguments: + +- `--prompt` (or `-p`): Text prompt for generation (default: `A cute cat`) +- `--output` (or `-o`): Output file path for image results (default: `bagel_output.png`) +- `--server` (or `-s`): Server URL (default: `http://localhost:8091`) +- `--image-url` (or `-i`): Input image URL or local file path (for img2img/img2text modes) +- `--modality` (or `-m`): Task modality (default: `text2img`). Options: `text2img`, `img2img`, `img2text`, `text2text` +- `--height`: Image height in pixels (default: 512) +- `--width`: Image width in pixels (default: 512) +- `--steps`: Number of inference steps (default: 25) +- `--seed`: Random seed (default: 42) +- `--negative`: Negative prompt for image generation + +Example with custom parameters: + +```bash +python openai_chat_client.py \ + --prompt "A futuristic city" \ + --modality text2img \ + --height 768 \ + --width 768 \ + --steps 50 \ + --seed 42 \ + --negative "blurry, low quality" +``` + +## Modality Control + +BAGEL-7B-MoT supports **multiple modality modes** for different use cases. + +The default yaml configuration deploys Thinker and DiT on the same GPU. 
You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml) + +| Modality | Input | Output | Description | +| ----------- | ------------ | ------ | -------------------------------------- | +| `text2img` | Text | Image | Generate images from text prompts | +| `img2img` | Image + Text | Image | Transform images using text guidance | +| `img2text` | Image + Text | Text | Generate text descriptions from images | +| `text2text` | Text | Text | Pure text generation | + +### Text to Image (text2img) + +Generate images from text prompts: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "A beautiful sunset over mountains" \ + --modality text2img \ + --output sunset.png \ + --steps 50 +``` + +**Using curl** + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>A beautiful sunset over mountains<|im_end|>"}]}], + "modalities": ["image"], + "height": 512, + "width": 512, + "num_inference_steps": 50, + "seed": 42 + }' +``` + + +### Image to Image (img2img) + +Transform images based on text prompts: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "Make the cat stand up" \ + --modality img2img \ + --image-url /path/to/input.jpg \ + --output transformed.png +``` + +**Using curl** + +```bash +IMAGE_BASE64=$(base64 -w 0 cat.jpg) + +cat <<EOF > payload.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "<|im_start|>Make the cat stand up<|im_end|>"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,${IMAGE_BASE64}"}} + ] + }], + "modalities": ["image"], + "height": 512, + "width": 512, + "num_inference_steps": 50, + "seed": 42 +} +EOF + +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @payload.json + +``` + +### Image to Text (img2text) + +Generate text descriptions from images: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "Describe this image in detail" \ + --modality img2text \ + --image-url /path/to/image.jpg +``` + +**Using curl** + +```bash +IMAGE_BASE64=$(base64 -w 0 cat.jpg) + +cat <<EOF > payload.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "<|im_start|>user\n<|image_pad|>\nDescribe this image in detail<|im_end|>\n<|im_start|>assistant\n"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,${IMAGE_BASE64}"}} + ] + }], + "modalities": ["text"] +} +EOF + +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @payload.json +``` + +### Text to Text (text2text) + +Pure text generation: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "What is the capital of France?" \ + --modality text2text +``` + +**Using curl** + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}] + "modalities": ["text"] + }' +``` + +## FAQ + +- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. 
+ +```bash +sudo apt update +sudo apt install ffmpeg +``` + +- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. + +| Stage | VRAM | +| :------------------ | :--------------------------- | +| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** | +| Stage-1 (DiT) | **26.50 GiB** | +| Total | **~42 GiB + KV Cache** | diff --git a/docs/user_guide/examples/online_serving/image_to_image.md b/docs/user_guide/examples/online_serving/image_to_image.md new file mode 100644 index 0000000000000000000000000000000000000000..b89f0fe825f411d1be468a8be8ca883fc90c03f8 --- /dev/null +++ b/docs/user_guide/examples/online_serving/image_to_image.md @@ -0,0 +1,240 @@ +# Image-To-Image + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/image_to_image>. + + +This example demonstrates how to deploy Qwen-Image-Edit model for online image editing service using vLLM-Omni. + +For **multi-image** input editing, use **Qwen-Image-Edit-2509** (QwenImageEditPlusPipeline) and send multiple images in the user message content. + +## Start Server + +### Basic Start + +```bash +vllm serve Qwen/Qwen-Image-Edit --omni --port 8092 +``` + +### Multi-Image Edit (Qwen-Image-Edit-2509) + +```bash +vllm serve Qwen/Qwen-Image-Edit-2509 --omni --port 8092 +``` + +### Start with Parameters + + +Or use the startup script: + +```bash +bash run_server.sh +``` + +To serve Qwen-Image-Edit-2509 with the script: + +```bash +MODEL=Qwen/Qwen-Image-Edit-2509 bash run_server.sh +``` + +## API Calls + +### Method 1: Using curl (Image Editing) + +```bash +# Image editing +bash run_curl_image_edit.sh input.png "Convert this image to watercolor style" + +# Or execute directly +IMG_B64=$(base64 -w0 input.png) + +cat <<EOF > request.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Convert this image to watercolor style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,$IMG_B64"}} + ] + }], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 1, + "seed": 42 + } +} +EOF + +curl -s http://localhost:8092/v1/chat/completions -H "Content-Type: application/json" -d @request.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > output.png +``` + +### Method 2: Using Python Client + +```bash +python openai_chat_client.py --input input.png --prompt "Convert to oil painting style" --output output.png + +# Multi-image editing (Qwen-Image-Edit-2509 server required) +python openai_chat_client.py --input input1.png input2.png --prompt "Combine these images into a single scene" --output output.png +``` + +### Method 3: Using Gradio Demo + +```bash +python gradio_demo.py +# Visit http://localhost:7861 +``` + +## Request Format + +### Image Editing (Using image_url Format) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Convert this image to watercolor style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} + ] + } + ] +} +``` + +### Image Editing (Using Simplified image Format) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"text": "Convert this image to watercolor style"}, + {"image": "BASE64_IMAGE_DATA"} + ] + } + ] +} +``` + +### Image Editing with Parameters + +Use `extra_body` to pass generation parameters: + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Convert to ink wash 
painting style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} + ] + } + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "seed": 42 + } +} +``` + +### Multi-Image Editing (Qwen-Image-Edit-2509) + +Provide multiple images in `content` (order matters): + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Combine these images into a single scene"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."} }, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."} } + ] + } + ] +} +``` + +## Generation Parameters (extra_body) + +| Parameter | Type | Default | Description | +| ------------------------ | ----- | ------- | ------------------------------------- | +| `height` | int | None | Output image height in pixels | +| `width` | int | None | Output image width in pixels | +| `size` | str | None | Output image size (e.g., "1024x1024") | +| `num_inference_steps` | int | 50 | Number of denoising steps | +| `guidance_scale` | float | 7.5 | CFG guidance scale | +| `seed` | int | None | Random seed (reproducible) | +| `negative_prompt` | str | None | Negative prompt | +| `num_outputs_per_prompt` | int | 1 | Number of images to generate | + +## Response Format + +```json +{ + "id": "chatcmpl-xxx", + "created": 1234567890, + "model": "Qwen/Qwen-Image-Edit", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": [{ + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,..." + } + }] + }, + "finish_reason": "stop" + }], + "usage": {...} +} +``` + +## Common Editing Instructions Examples + +| Instruction | Description | +| ---------------------------------------- | ---------------- | +| `Convert this image to watercolor style` | Style transfer | +| `Convert the image to black and white` | Desaturation | +| `Enhance the color saturation` | Color adjustment | +| `Convert to cartoon style` | Cartoonization | +| `Add vintage filter effect` | Filter effect | +| `Convert daytime scene to nighttime` | Scene conversion | + +## File Description + +| File | Description | +| ------------------------ | ---------------------------- | +| `run_server.sh` | Server startup script | +| `run_curl_image_edit.sh` | curl image editing example | +| `openai_chat_client.py` | Python client | +| `gradio_demo.py` | Gradio interactive interface | + +## Example materials + +??? abstract "gradio_demo.py" + ``````py + --8<-- "examples/online_serving/image_to_image/gradio_demo.py" + `````` +??? abstract "openai_chat_client.py" + ``````py + --8<-- "examples/online_serving/image_to_image/openai_chat_client.py" + `````` +??? abstract "run_curl_image_edit.sh" + ``````sh + --8<-- "examples/online_serving/image_to_image/run_curl_image_edit.sh" + `````` +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/image_to_image/run_server.sh" + `````` diff --git a/docs/user_guide/examples/online_serving/lora_inference.md b/docs/user_guide/examples/online_serving/lora_inference.md new file mode 100644 index 0000000000000000000000000000000000000000..4c8b215d299bbc1189a139ce79a4f8078a4d433d --- /dev/null +++ b/docs/user_guide/examples/online_serving/lora_inference.md @@ -0,0 +1,69 @@ +# LoRA-Inference + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/lora_inference>. 
+ +This example shows how to use **per-request LoRA** with vLLM-Omni diffusion models via the OpenAI-compatible Chat Completions API. + +> Note: The LoRA adapter path must be readable on the **server** machine (usually a local path or a mounted directory). +> Note: This example uses `/v1/chat/completions`. LoRA payloads for other OpenAI endpoints are not implemented here. + +## Start Server + +```bash +# Pick a diffusion model (examples) +# export MODEL=stabilityai/stable-diffusion-3.5-medium +# export MODEL=Qwen/Qwen-Image + +bash run_server.sh +``` + +## Call API (curl) + +```bash +# Required: local LoRA folder on the server +export LORA_PATH=/path/to/lora_adapter + +# Optional +export SERVER=http://localhost:8091 +export PROMPT="A piece of cheesecake" +export LORA_NAME=my_lora +export LORA_SCALE=1.0 +# Optional: if omitted, the server derives a stable id from LORA_PATH. +# export LORA_INT_ID=123 + +bash run_curl_lora_inference.sh +``` + +## Call API (Python) + +```bash +python openai_chat_client.py \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +``` + +## LoRA Format + +LoRA adapters should be in PEFT format, for example: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` + +??? abstract "openai_chat_client.py" + ``````py + --8<-- "examples/online_serving/lora_inference/openai_chat_client.py" + `````` +??? abstract "run_curl_lora_inference.sh" + ``````py + --8<-- "examples/online_serving/lora_inference/run_curl_lora_inference.sh" + `````` +??? abstract "run_server.sh" + ``````py + --8<-- "examples/online_serving/lora_inference/run_server.sh" + `````` diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md new file mode 100644 index 0000000000000000000000000000000000000000..361044ed8fd3c06872afaece8477ad7e6aca3b96 --- /dev/null +++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md @@ -0,0 +1,237 @@ +# Qwen2.5-Omni + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/qwen2_5_omni>. + + +## 🛠️ Installation + +Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md) + +## Run examples (Qwen2.5-Omni) + +### Launch the Server + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +Get into the example folder +```bash +cd examples/online_serving/qwen2_5_omni +``` + +#### Send request via python + +```bash +python openai_chat_completion_client_for_multimodal_generation.py --query-type mixed_modalities +``` + +The Python client supports the following command-line arguments: + +- `--query-type` (or `-q`): Query type (default: `mixed_modalities`). Options: `mixed_modalities`, `use_audio_in_video`, `multi_audios`, `text` +- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` +- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. 
Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` +- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` +- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` + + +For example, to use mixed modalities with all local files: + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --video-path /path/to/your/video.mp4 \ + --image-path /path/to/your/image.jpg \ + --audio-path /path/to/your/audio.wav \ + --prompt "Analyze all the media content and provide a comprehensive summary." +``` + +#### Send request via curl + +```bash +bash run_curl_multimodal_generation.sh mixed_modalities +``` + +## Modality control +You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. + +### Supported modalities + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| `["audio"]` | Text + Audio | +| `["text", "audio"]` | Text + Audio | +| Not specified | Text + Audio (default) | + +### Using curl + +#### Text only + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-Omni-7B", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["text"] + }' +``` + +#### Text + Audio + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-Omni-7B", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["audio"] + }' +``` + +### Using Python client + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --modalities text +``` + +### Using OpenAI Python SDK + +#### Text only + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen2.5-Omni-7B", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["text"] +) +print(response.choices[0].message.content) +``` + +#### Text + Audio + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen2.5-Omni-7B", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["audio"] +) +# Response contains two choices: one with text, one with audio +print(response.choices[0].message.content) # Text response +print(response.choices[1].message.audio) # Audio response +``` + +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. 
Other modalities can output normally. +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --stream +``` + +## Run Local Web UI Demo + +This Web UI demo allows users to interact with the model through a web browser. + +### Running Gradio Demo + +The Gradio demo connects to a vLLM API server. You have two options: + +#### Option 1: One-step Launch Script (Recommended) + +The convenience script launches both the vLLM server and Gradio demo together: + +```bash +./run_gradio_demo.sh --model Qwen/Qwen2.5-Omni-7B --server-port 8091 --gradio-port 7861 +``` + +This script will: +1. Start the vLLM server in the background +2. Wait for the server to be ready +3. Launch the Gradio demo +4. Handle cleanup when you press Ctrl+C + +The script supports the following arguments: +- `--model`: Model name/path (default: Qwen/Qwen2.5-Omni-7B) +- `--server-port`: Port for vLLM server (default: 8091) +- `--gradio-port`: Port for Gradio demo (default: 7861) +- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--server-host`: Host for vLLM server (default: 0.0.0.0) +- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) +- `--share`: Share Gradio demo publicly (creates a public link) + +#### Option 2: Manual Launch (Two-Step Process) + +**Step 1: Launch the vLLM API server** + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 +``` + +If you have custom stage configs file: +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +**Step 2: Run the Gradio demo** + +In a separate terminal: + +```bash +python gradio_demo.py --model Qwen/Qwen2.5-Omni-7B --api-base http://localhost:8091/v1 --port 7861 +``` + +Then open `http://localhost:7861/` on your local browser to interact with the web UI. + +The gradio script supports the following arguments: + +- `--model`: Model name/path (should match the server model) +- `--api-base`: Base URL for the vLLM API server (default: http://localhost:8091/v1) +- `--ip`: Host/IP for Gradio server (default: 127.0.0.1) +- `--port`: Port for Gradio server (default: 7861) +- `--share`: Share the Gradio demo publicly (creates a public link) + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` + +## Example materials + +??? abstract "gradio_demo.py" + ``````py + --8<-- "examples/online_serving/qwen2_5_omni/gradio_demo.py" + `````` +??? abstract "openai_chat_completion_client_for_multimodal_generation.py" + ``````py + --8<-- "examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py" + `````` +??? abstract "run_curl_multimodal_generation.sh" + ``````sh + --8<-- "examples/online_serving/qwen2_5_omni/run_curl_multimodal_generation.sh" + `````` +??? abstract "run_gradio_demo.sh" + ``````sh + --8<-- "examples/online_serving/qwen2_5_omni/run_gradio_demo.sh" + `````` diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md new file mode 100644 index 0000000000000000000000000000000000000000..d2624653394d4b1064038e43d6f71e2c4ace688f --- /dev/null +++ b/docs/user_guide/examples/online_serving/qwen3_omni.md @@ -0,0 +1,247 @@ +# Qwen3-Omni + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/qwen3_omni>. 
+
+
+## 🛠️ Installation
+
+Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md)
+
+## Run examples (Qwen3-Omni)
+
+### Launch the Server
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
+```
+
+If you want to enable async chunking for Qwen3-Omni, launch the server with the command below:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+```
+
+If you have a custom stage configs file, launch the server with the command below:
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+```
+
+### Send Multi-modal Request
+
+Get into the example folder
+```bash
+cd examples/online_serving/qwen3_omni
+```
+
+#### Send request via python
+
+```bash
+python openai_chat_completion_client_for_multimodal_generation.py --query-type use_image
+```
+
+The Python client supports the following command-line arguments:
+
+- `--query-type` (or `-q`): Query type (default: `use_video`). Options: `text`, `use_audio`, `use_image`, `use_video`
+- `--model` (or `-m`): Model name/path (default: `Qwen/Qwen3-Omni-30B-A3B-Instruct`)
+- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type is `use_video`, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4`
+- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type is `use_image`, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png`
+- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type is `use_audio`, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3`
+- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"`
+
+
+For example, to use a local video file with a custom prompt:
+
+```bash
+python openai_chat_completion_client_for_multimodal_generation.py \
+    --query-type use_video \
+    --video-path /path/to/your/video.mp4 \
+    --prompt "What are the main activities shown in this video?"
+```
+
+#### Send request via curl
+
+```bash
+bash run_curl_multimodal_generation.sh use_image
+```
+
+
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
+## Modality control
+You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance.
+ +### Supported modalities + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| `["audio"]` | Text + Audio | +| `["text", "audio"]` | Text + Audio | +| Not specified | Text + Audio (default) | + +### Using curl + +#### Text only + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["text"] + }' +``` + +#### Text + Audio + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["audio"] + }' +``` + +### Using Python client + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --modalities text +``` + +### Using OpenAI Python SDK + +#### Text only + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen3-Omni-30B-A3B-Instruct", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["text"] +) +print(response.choices[0].message.content) +``` + +#### Text + Audio + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen3-Omni-30B-A3B-Instruct", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["audio"] +) +# Response contains two choices: one with text, one with audio +print(response.choices[0].message.content) # Text response +print(response.choices[1].message.audio) # Audio response +``` + +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --stream +``` + +## Run Local Web UI Demo + +This Web UI demo allows users to interact with the model through a web browser. + +### Running Gradio Demo + +The Gradio demo connects to a vLLM API server. You have two options: + +#### Option 1: One-step Launch Script (Recommended) + +The convenience script launches both the vLLM server and Gradio demo together: + +```bash +./run_gradio_demo.sh --model Qwen/Qwen3-Omni-30B-A3B-Instruct --server-port 8091 --gradio-port 7861 +``` + +This script will: +1. Start the vLLM server in the background +2. Wait for the server to be ready +3. Launch the Gradio demo +4. 
Handle cleanup when you press Ctrl+C + +The script supports the following arguments: +- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct) +- `--server-port`: Port for vLLM server (default: 8091) +- `--gradio-port`: Port for Gradio demo (default: 7861) +- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--server-host`: Host for vLLM server (default: 0.0.0.0) +- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) +- `--share`: Share Gradio demo publicly (creates a public link) + +#### Option 2: Manual Launch (Two-Step Process) + +**Step 1: Launch the vLLM API server** + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + +If you have custom stage configs file: +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +**Step 2: Run the Gradio demo** + +In a separate terminal: + +```bash +python gradio_demo.py --model Qwen/Qwen3-Omni-30B-A3B-Instruct --api-base http://localhost:8091/v1 --port 7861 +``` + +Then open `http://localhost:7861/` on your local browser to interact with the web UI. + +The gradio script supports the following arguments: + +- `--model`: Model name/path (should match the server model) +- `--api-base`: Base URL for the vLLM API server (default: http://localhost:8091/v1) +- `--ip`: Host/IP for Gradio server (default: 127.0.0.1) +- `--port`: Port for Gradio server (default: 7861) +- `--share`: Share the Gradio demo publicly (creates a public link) + +## Example materials + +??? abstract "gradio_demo.py" + ``````py + --8<-- "examples/online_serving/qwen3_omni/gradio_demo.py" + `````` +??? abstract "openai_chat_completion_client_for_multimodal_generation.py" + ``````py + --8<-- "examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py" + `````` +??? abstract "qwen3_omni_moe_thinking.yaml" + ``````yaml + --8<-- "examples/online_serving/qwen3_omni/qwen3_omni_moe_thinking.yaml" + `````` +??? abstract "run_curl_multimodal_generation.sh" + ``````sh + --8<-- "examples/online_serving/qwen3_omni/run_curl_multimodal_generation.sh" + `````` +??? abstract "run_gradio_demo.sh" + ``````sh + --8<-- "examples/online_serving/qwen3_omni/run_gradio_demo.sh" + `````` diff --git a/docs/user_guide/examples/online_serving/text_to_image.md b/docs/user_guide/examples/online_serving/text_to_image.md new file mode 100644 index 0000000000000000000000000000000000000000..2de2b4fcb409bbee03953c67d1a084f8c4d4feb7 --- /dev/null +++ b/docs/user_guide/examples/online_serving/text_to_image.md @@ -0,0 +1,181 @@ +# Text-To-Image + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/text_to_image>. + + +This example demonstrates how to deploy Qwen-Image model for online image generation service using vLLM-Omni. + +## Start Server + +### Basic Start + +```bash +vllm serve Qwen/Qwen-Image --omni --port 8091 +``` +!!! 
note + If you encounter Out-of-Memory (OOM) issues or have limited GPU memory, you can enable VAE slicing and tiling to reduce memory usage, --vae-use-slicing --vae-use-tiling + +### Start with Parameters + +Or use the startup script: + +```bash +bash run_server.sh +``` + +## API Calls + +### Method 1: Using curl + +```bash +# Basic text-to-image generation +bash run_curl_text_to_image.sh + +# Or execute directly +curl -s http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42 + } + }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png +``` + +### Method 2: Using Python Client + +```bash +python openai_chat_client.py --prompt "A beautiful landscape painting" --output output.png +``` + +### Method 3: Using Gradio Demo + +```bash +python gradio_demo.py +# Visit http://localhost:7860 +``` + +## Request Format + +### Simple Text Generation + +```json +{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ] +} +``` + +### Generation with Parameters + +Use `extra_body` to pass generation parameters: + +```json +{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42 + } +} +``` + +### Multimodal Input (Text + Structured Content) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "A beautiful landscape painting"} + ] + } + ] +} +``` + +## Generation Parameters (extra_body) + +| Parameter | Type | Default | Description | +| ------------------------ | ----- | ------- | ------------------------------ | +| `height` | int | None | Image height in pixels | +| `width` | int | None | Image width in pixels | +| `size` | str | None | Image size (e.g., "1024x1024") | +| `num_inference_steps` | int | 50 | Number of denoising steps | +| `true_cfg_scale` | float | 4.0 | Qwen-Image CFG scale | +| `seed` | int | None | Random seed (reproducible) | +| `negative_prompt` | str | None | Negative prompt | +| `num_outputs_per_prompt` | int | 1 | Number of images to generate | +| `--cfg-parallel-size`. | int | 1 | Number of GPUs for CFG parallelism | + +## Response Format + +```json +{ + "id": "chatcmpl-xxx", + "created": 1234567890, + "model": "Qwen/Qwen-Image", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": [{ + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,..." + } + }] + }, + "finish_reason": "stop" + }], + "usage": {...} +} +``` + +## Extract Image + +```bash +# Extract base64 from response and decode to image +cat response.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png +``` + +## File Description + +| File | Description | +| --------------------------- | ---------------------------- | +| `run_server.sh` | Server startup script | +| `run_curl_text_to_image.sh` | curl example | +| `openai_chat_client.py` | Python client | +| `gradio_demo.py` | Gradio interactive interface | + +## Example materials + +??? abstract "gradio_demo.py" + ``````py + --8<-- "examples/online_serving/text_to_image/gradio_demo.py" + `````` +??? 
abstract "openai_chat_client.py" + ``````py + --8<-- "examples/online_serving/text_to_image/openai_chat_client.py" + `````` +??? abstract "run_curl_text_to_image.sh" + ``````sh + --8<-- "examples/online_serving/text_to_image/run_curl_text_to_image.sh" + `````` +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/text_to_image/run_server.sh" + `````` diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7bcfb42fa1c646a6aa31ebd28fb3709466b78367 --- /dev/null +++ b/examples/offline_inference/bagel/README.md @@ -0,0 +1,177 @@ +# BAGEL-7B-MoT + +## Set up + +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices. + +Get into the bagel folder + +```bash +cd examples/offline_inference/bagel +``` + +### Modality Control + +BAGEL-7B-MoT supports multiple modality modes. You can control the mode using the `--modality` argument: + +#### Text to Image (text2img) + +- **Pipeline**: Text → Thinker → DiT → VAE Decode → Image +- **Stages Used**: Stage 0 (Thinker) + Stage 1 (DiT) +- **KV Transfer**: Thinker sends KV cache to DiT for conditioned generation + +Generate images from text prompts: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2img \ + --prompts "A cute cat" +``` + +#### Image to Image (img2img) + +- **Pipeline**: Image → VAE Encode → DiT → VAE Decode → New Image +- **Stages Used**: Stage 1 (DiT) only +- **Special**: Bypasses the Thinker stage, direct image-to-image transformation + +Transform images based on text prompts: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality img2img \ + --image-path /path/to/image.jpg \ + --prompts "Let the woman wear a blue dress" +``` + +#### Image to Text (img2text) + +- **Pipeline**: Image → ViT + VAE Encode → Thinker → Text Output +- **Stages Used**: Stage 0 (Thinker) only +- **Special**: Uses both VAE latent encoding AND ViT semantic encoding for comprehensive image understanding + +Generate text descriptions from images: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality img2text \ + --image-path /path/to/image.jpg \ + --prompts "Describe this image in detail" +``` + +#### Text to Text (text2text) + +- **Pipeline**: Text → Thinker → Text Output +- **Stages Used**: Stage 0 (Thinker) only +- **Special**: No visual components involved, operates as pure language model + +Pure text generation: + +```bash +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2text \ + --prompts "What is the capital of France?" 
+ +# You can load prompts from a text file (one prompt per line): +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2text \ + --txt-prompts /path/to/prompts.txt +``` + +### Inference Steps + +Control the number of inference steps for image generation: + +```bash +# You can adjust steps to 100 to improve image quality +python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \ + --modality text2img \ + --steps 50 \ + --prompts "A cute cat" +``` + +### Key arguments + +BAGEL-7B-MoT supports **multiple modality modes** for different use cases. + +The default yaml configuration deploys Thinker and DiT on the same GPU. You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml) + +#### 📌 Command Line Arguments (end2end.py) + +| Argument | Type | Default | Description | +| :--------------------- | :----- | :---------------------------- | :----------------------------------------------------------- | +| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or name | +| `--modality` | choice | `text2img` | Modality mode: `text2img`, `img2img`, `img2text`, `text2text` | +| `--prompts` | list | `None` | Input text prompts directly | +| `--txt-prompts` | string | `None` | Path to txt file with one prompt per line | +| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) | +| `--steps` | int | `50` | Number of inference steps | +| `--stage-configs-path` | string | `None` | Custom stage config file path | +| `--worker-backend` | choice | `process` | Worker backend: `process` or `ray` | +| `--ray-address` | string | `None` | Ray cluster address | +| `--enable-stats` | flag | `False` | Enable statistics logging | +| `--init-sleep-seconds` | int | `20` | Initialization sleep time | +| `--batch-timeout` | int | `5` | Batch timeout | +| `--init-timeout` | int | `300` | Initialization timeout | + +------ + +#### ⚙️ Stage Configuration Parameters (bagel.yaml) + + **Stage 0 - Thinker (LLM Stage)** + +| Parameter | Value | Description | +| :------------------------------- | :------------------------------ | :----------------------- | +| `stage_type` | `llm` | Stage type | +| `devices` | `"0"` | GPU device ID | +| `max_batch_size` | `1` | Maximum batch size | +| `model_stage` | `thinker` | Model stage identifier | +| `model_arch` | `BagelForConditionalGeneration` | Model architecture | +| `gpu_memory_utilization` | `0.4` | GPU memory utilization | +| `tensor_parallel_size` | `1` | Tensor parallel size | +| `max_num_batched_tokens` | `32768` | Maximum batched tokens | +| `omni_kv_config.need_send_cache` | `true` | Whether to send KV cache | + +------ + +**Stage 1 - DiT (Diffusion Stage)** + +| Parameter | Value | Description | +| :------------------------------- | :---------- | :-------------------------- | +| `stage_type` | `diffusion` | Stage type | +| `devices` | `"0"` | GPU device ID | +| `max_batch_size` | `1` | Maximum batch size | +| `model_stage` | `dit` | Model stage identifier | +| `gpu_memory_utilization` | `0.4` | GPU memory utilization | +| `omni_kv_config.need_recv_cache` | `true` | Whether to receive KV cache | +| `engine_input_source` | `[0]` | Input source from Stage 0 | + +------ + +#### 🔗 Runtime Configuration + +| Parameter | Value | Description | +| :-------------------- | :------ | :------------------------------- | +| `window_size` | `-1` | Window size (-1 means unlimited) | +| `max_inflight` | `1` | Maximum inflight requests | +| `shm_threshold_bytes` | `65536` | Shared 
memory threshold (64KB) | + +## FAQ + +- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. + +```bash +sudo apt update +sudo apt install ffmpeg +``` + +- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. + +| Stage | VRAM | +| :------------------ | :--------------------------- | +| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** | +| Stage-1 (DiT) | **26.50 GiB** | +| Total | **~42 GiB + KV Cache** | diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..397fd333ec0da840bb8c31c57b8767adbd0cf871 --- /dev/null +++ b/examples/offline_inference/bagel/end2end.py @@ -0,0 +1,189 @@ +import argparse +import os +from typing import cast + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + default="ByteDance-Seed/BAGEL-7B-MoT", + help="Path to merged model directory.", + ) + parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.") + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + parser.add_argument("--prompt_type", default="text", choices=["text"]) + + parser.add_argument( + "--modality", + default="text2img", + choices=["text2img", "img2img", "img2text", "text2text"], + help="Modality mode to control stage execution.", + ) + + parser.add_argument( + "--image-path", + type=str, + default=None, + help="Path to input image for img2img.", + ) + + # OmniLLM init args + parser.add_argument("--enable-stats", action="store_true", default=False) + parser.add_argument("--init-sleep-seconds", type=int, default=20) + parser.add_argument("--batch-timeout", type=int, default=5) + parser.add_argument("--init-timeout", type=int, default=300) + parser.add_argument("--shm-threshold-bytes", type=int, default=65536) + parser.add_argument("--worker-backend", type=str, default="process", choices=["process", "ray"]) + parser.add_argument("--ray-address", type=str, default=None) + parser.add_argument("--stage-configs-path", type=str, default=None) + parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + model_name = args.model + prompts: list[OmniPromptType] = [] + try: + # Preferred: load from txt file (one prompt per line) + if getattr(args, "txt_prompts", None) and args.prompt_type == "text": + with open(args.txt_prompts, encoding="utf-8") as f: + lines = [ln.strip() for ln in f.readlines()] + prompts = [ln for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + else: + prompts = args.prompts + except Exception as e: + print(f"[Error] Failed to load prompts: {e}") + raise + + if not prompts: + # Default prompt for text2img test if none provided + prompts = ["<|im_start|>A cute cat<|im_end|>"] + print(f"[Info] No prompts provided, using default: {prompts}") + omni_outputs = [] + + from PIL import Image + + if args.modality == "img2img": + from PIL import Image + + from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion + + print("[Info] Running in img2img mode (Stage 1 only)") + client = OmniDiffusion(model=model_name) + + if args.image_path: + if 
os.path.exists(args.image_path): + loaded_image = Image.open(args.image_path).convert("RGB") + prompts = [ + { + "prompt": cast(str, p), + "multi_modal_data": {"image": loaded_image}, + } + for p in prompts + ] + else: + print(f"[Warning] Image path {args.image_path} does not exist.") + + result = client.generate( + prompts, + OmniDiffusionSamplingParams( + seed=52, + need_kv_receive=False, + num_inference_steps=args.steps, + ), + ) + + # Ensure result is a list for iteration + if not isinstance(result, list): + omni_outputs = [result] + else: + omni_outputs = result + + else: + from vllm_omni.entrypoints.omni import Omni + + omni_kwargs = {} + if args.stage_configs_path: + omni_kwargs["stage_configs_path"] = args.stage_configs_path + + omni_kwargs.update( + { + "log_stats": args.enable_stats, + "init_sleep_seconds": args.init_sleep_seconds, + "batch_timeout": args.batch_timeout, + "init_timeout": args.init_timeout, + "shm_threshold_bytes": args.shm_threshold_bytes, + "worker_backend": args.worker_backend, + "ray_address": args.ray_address, + } + ) + + omni = Omni(model=model_name, **omni_kwargs) + + formatted_prompts = [] + for p in args.prompts: + if args.modality == "img2text": + if args.image_path: + loaded_image = Image.open(args.image_path).convert("RGB") + final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n" + prompt_dict = { + "prompt": final_prompt_text, + "multi_modal_data": {"image": loaded_image}, + "modalities": ["text"], + } + formatted_prompts.append(prompt_dict) + elif args.modality == "text2text": + final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n" + prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]} + formatted_prompts.append(prompt_dict) + else: + # text2img + final_prompt_text = f"<|im_start|>{p}<|im_end|>" + prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]} + formatted_prompts.append(prompt_dict) + + params_list = omni.default_sampling_params_list + if args.modality == "text2img": + params_list[0].max_tokens = 1 # type: ignore # The first stage is a SamplingParam (vllm) + if len(params_list) > 1: + params_list[1].num_inference_steps = args.steps # type: ignore # The second stage is an OmniDiffusionSamplingParam + + omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) + + for i, req_output in enumerate(omni_outputs): + images = getattr(req_output, "images", None) + if not images and hasattr(req_output, "output"): + if isinstance(req_output.output, list): + images = req_output.output + else: + images = [req_output.output] + + if images: + for j, img in enumerate(images): + img.save(f"output_{i}_{j}.png") + + if hasattr(req_output, "request_output") and req_output.request_output: + for stage_out in req_output.request_output: + if hasattr(stage_out, "images") and stage_out.images: + for k, img in enumerate(stage_out.images): + save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png" + img.save(save_path) + print(f"[Info] Saved stage output image to {save_path}") + + print(omni_outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..8f330e09d20194591818ae90ee69288164379e7f --- /dev/null +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script for image editing with Qwen-Image-Edit. + +Usage (single image): + python image_edit.py \ + --image input.png \ + --prompt "Let this mascot dance under the moon, surrounded by floating stars and poetic bubbles such as 'Be Kind'" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --guidance_scale 1.0 + +Usage (multiple images): + python image_edit.py \ + --image input1.png input2.png input3.png \ + --prompt "Combine these images into a single scene" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --guidance_scale 1.0 + +Usage (with cache-dit acceleration): + python image_edit.py \ + --image input.png \ + --prompt "Edit description" \ + --cache_backend cache_dit \ + --cache_dit_max_continuous_cached_steps 3 \ + --cache_dit_residual_diff_threshold 0.24 \ + --cache_dit_enable_taylorseer + +Usage (with tea_cache acceleration): + python image_edit.py \ + --image input.png \ + --prompt "Edit description" \ + --cache_backend tea_cache \ + --tea_cache_rel_l1_thresh 0.25 + +Usage (layered): + python image_edit.py \ + --model "Qwen/Qwen-Image-Layered" \ + --image input.png \ + --prompt "" \ + --output "layered" \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --layers 4 \ + --color-format "RGBA" + +Usage (with CFG Parallel): + python image_edit.py \ + --image input.png \ + --prompt "Edit description" \ + --cfg_parallel_size 2 \ + --num_inference_steps 50 \ + --cfg_scale 4.0 + +Usage (disable torch.compile): + python image_edit.py \ + --image input.png \ + --prompt "Edit description" \ + --enforce_eager \ + --num_inference_steps 50 \ + --cfg_scale 4.0 + +For more options, run: + python image_edit.py --help +""" + +import argparse +import os +import time +from pathlib import Path + +import torch +from PIL import Image + +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Edit an image with Qwen-Image-Edit.") + parser.add_argument( + "--model", + default="Qwen/Qwen-Image-Edit", + help=( + "Diffusion model name or local path. " + "For multiple image inputs, use Qwen/Qwen-Image-Edit-2509 or Qwen/Qwen-Image-Edit-2511" + "which supports QwenImageEditPlusPipeline." + ), + ) + parser.add_argument( + "--image", + type=str, + nargs="+", + required=True, + help="Path(s) to input image file(s) (PNG, JPG, etc.). Can specify multiple images.", + ) + parser.add_argument( + "--prompt", + type=str, + required=True, + help="Text prompt describing the edit to make to the image.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed for deterministic results.", + ) + parser.add_argument( + "--cfg_scale", + type=float, + default=4.0, + help=( + "True classifier-free guidance scale (default: 4.0). Guidance scale as defined in Classifier-Free " + "Diffusion Guidance. Classifier-free guidance is enabled by setting cfg_scale > 1 and providing " + "a negative_prompt. Higher guidance scale encourages images closely linked to the text prompt, " + "usually at the expense of lower image quality." 
+ ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help=( + "Guidance scale for guidance-distilled models (default: 1.0, disabled). " + "Unlike classifier-free guidance (--cfg_scale), guidance-distilled models take the guidance scale " + "directly as an input parameter. Enabled when guidance_scale > 1. Ignored when not using guidance-distilled models." + ), + ) + parser.add_argument( + "--output", + type=str, + default="output_image_edit.png", + help=("Path to save the edited image (PNG). Or prefix for Qwen-Image-Layered model save images(PNG)."), + ) + parser.add_argument( + "--num_outputs_per_prompt", + type=int, + default=1, + help="Number of images to generate for the given prompt.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit", "tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " + "Default: None (no cache acceleration)." + ), + ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) + parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.") + parser.add_argument( + "--resolution", + type=int, + default=640, + help="Bucket in (640, 1024) to determine the condition and output resolution", + ) + + parser.add_argument( + "--color-format", + type=str, + default="RGB", + help="For Qwen-Image-Layered, set to RGBA.", + ) + + # Cache-DiT specific parameters + parser.add_argument( + "--cache_dit_fn_compute_blocks", + type=int, + default=1, + help="[cache-dit] Number of forward compute blocks. Optimized for single-transformer models.", + ) + parser.add_argument( + "--cache_dit_bn_compute_blocks", + type=int, + default=0, + help="[cache-dit] Number of backward compute blocks.", + ) + parser.add_argument( + "--cache_dit_max_warmup_steps", + type=int, + default=4, + help="[cache-dit] Maximum warmup steps (works for few-step models).", + ) + parser.add_argument( + "--cache_dit_residual_diff_threshold", + type=float, + default=0.24, + help="[cache-dit] Residual diff threshold. 
Higher values enable more aggressive caching.", + ) + parser.add_argument( + "--cache_dit_max_continuous_cached_steps", + type=int, + default=3, + help="[cache-dit] Maximum continuous cached steps to prevent precision degradation.", + ) + parser.add_argument( + "--cache_dit_enable_taylorseer", + action="store_true", + default=False, + help="[cache-dit] Enable TaylorSeer acceleration (not suitable for few-step models).", + ) + parser.add_argument( + "--cache_dit_taylorseer_order", + type=int, + default=1, + help="[cache-dit] TaylorSeer polynomial order.", + ) + parser.add_argument( + "--cache_dit_scm_steps_mask_policy", + type=str, + default=None, + choices=[None, "slow", "medium", "fast", "ultra"], + help="[cache-dit] SCM mask policy: None (disabled), slow, medium, fast, ultra.", + ) + parser.add_argument( + "--cache_dit_scm_steps_policy", + type=str, + default="dynamic", + choices=["dynamic", "static"], + help="[cache-dit] SCM steps policy: dynamic or static.", + ) + + # TeaCache specific parameters + parser.add_argument( + "--tea_cache_rel_l1_thresh", + type=float, + default=0.2, + help="[tea_cache] Threshold for accumulated relative L1 distance.", + ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) + parser.add_argument( + "--vae_use_slicing", + action="store_true", + help="Enable VAE slicing for memory optimization.", + ) + parser.add_argument( + "--vae_use_tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.", + ) + parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offloading for diffusion models.", + ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Validate input images exist and load them + input_images = [] + for image_path in args.image: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Input image not found: {image_path}") + + img = Image.open(image_path).convert(args.color_format) + input_images.append(img) + + # Use single image or list based on number of inputs + if len(input_images) == 1: + input_image = input_images[0] + else: + input_image = input_images + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + + parallel_config = DiffusionParallelConfig( + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + ) + + # Configure cache based on backend type + cache_config = None + if args.cache_backend == "cache_dit": + # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer + cache_config = { + "Fn_compute_blocks": args.cache_dit_fn_compute_blocks, + "Bn_compute_blocks": args.cache_dit_bn_compute_blocks, + "max_warmup_steps": args.cache_dit_max_warmup_steps, + "residual_diff_threshold": args.cache_dit_residual_diff_threshold, + "max_continuous_cached_steps": args.cache_dit_max_continuous_cached_steps, + "enable_taylorseer": args.cache_dit_enable_taylorseer, + 
"taylorseer_order": args.cache_dit_taylorseer_order, + "scm_steps_mask_policy": args.cache_dit_scm_steps_mask_policy, + "scm_steps_policy": args.cache_dit_scm_steps_policy, + } + elif args.cache_backend == "tea_cache": + # TeaCache configuration + cache_config = { + "rel_l1_thresh": args.tea_cache_rel_l1_thresh, + # Note: coefficients will use model-specific defaults based on model_type + } + + # Initialize Omni with appropriate pipeline + omni = Omni( + model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, + vae_use_slicing=args.vae_use_slicing, + vae_use_tiling=args.vae_use_tiling, + cache_backend=args.cache_backend, + cache_config=cache_config, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, + enable_cpu_offload=args.enable_cpu_offload, + ) + print("Pipeline loaded") + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + # Time profiling for generation + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + if isinstance(input_image, list): + print(f" Number of input images: {len(input_image)}") + for idx, img in enumerate(input_image): + print(f" Image {idx + 1} size: {img.size}") + else: + print(f" Input image size: {input_image.size}") + print( + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" + ) + print(f"{'=' * 60}\n") + + generation_start = time.perf_counter() + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Generate edited image + outputs = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": {"image": input_image}, + }, + OmniDiffusionSamplingParams( + generator=generator, + true_cfg_scale=args.cfg_scale, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_outputs_per_prompt, + layers=args.layers, + resolution=args.resolution, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + if not outputs: + raise ValueError("No output generated from omni.generate()") + + # Extract images from OmniRequestOutput + # omni.generate() returns list[OmniRequestOutput], extract images from request_output[0].images + first_output = outputs[0] + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = 
first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + if not images: + raise ValueError("No images found in request_output") + + # Save output image(s) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "output_image_edit" + + # Handle layered output (each image may be a list of layers) + if args.num_outputs_per_prompt <= 1: + img = images[0] + # Check if this is a layered output (list of images) + if isinstance(img, list): + for sub_idx, sub_img in enumerate(img): + save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}" + sub_img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + else: + img.save(output_path) + print(f"Saved edited image to {os.path.abspath(output_path)}") + else: + for idx, img in enumerate(images): + # Check if this is a layered output (list of images) + if isinstance(img, list): + for sub_idx, sub_img in enumerate(img): + save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}" + sub_img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + else: + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/image_to_image/image_to_image.md b/examples/offline_inference/image_to_image/image_to_image.md new file mode 100644 index 0000000000000000000000000000000000000000..d0986d6ee7fa0d2cc134f3524c5a68dbac7d2af4 --- /dev/null +++ b/examples/offline_inference/image_to_image/image_to_image.md @@ -0,0 +1,55 @@ +# Image-To-Image + +This example edits an input image with `Qwen/Qwen-Image-Edit` using the `image_edit.py` CLI. + +## Local CLI Usage + +### Single Image Editing + +Download the example image: + +```bash +wget https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/qwen-bear.png +``` + +Then run: + +```bash +python image_edit.py \ + --image qwen-bear.png \ + --prompt "Let this mascot dance under the moon, surrounded by floating stars and poetic bubbles such as 'Be Kind'" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 +``` + +### Multiple Image Editing (Qwen-Image-Edit-2509) + +For multiple image inputs, use `Qwen/Qwen-Image-Edit-2509` or `Qwen/Qwen-Image-Edit-2511`: + +```bash +python image_edit.py \ + --model Qwen/Qwen-Image-Edit-2509 \ + --image img1.png img2.png \ + --prompt "Combine these images into a single scene" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --guidance_scale 1.0 +``` + +Key arguments: + +- `--model`: model name or path. Use `Qwen/Qwen-Image-Edit-2509` or later for multiple image support. +- `--image`: path(s) to the source image(s) (PNG/JPG, converted to RGB). Can specify multiple images. +- `--prompt` / `--negative_prompt`: text description (string). +- `--cfg_scale`: true classifier-free guidance scale (default: 4.0). Classifier-free guidance is enabled by setting cfg_scale > 1 and providing a negative_prompt. Higher guidance scale encourages images closely linked to the text prompt, usually at the expense of lower image quality. +- `--guidance_scale`: guidance scale for guidance-distilled models (default: 1.0, disabled). 
Unlike classifier-free guidance (--cfg_scale), guidance-distilled models take the guidance scale directly as an input parameter. Enabled when guidance_scale > 1. Ignored when not using guidance-distilled models. +- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). +- `--output`: path to save the generated PNG. +- `--vae_use_slicing`: enable VAE slicing for memory optimization. +- `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--enable-cpu-offload`: enable CPU offloading for diffusion models. + +> ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/image_to_image/run_qwen_image_edit_2511.sh b/examples/offline_inference/image_to_image/run_qwen_image_edit_2511.sh new file mode 100644 index 0000000000000000000000000000000000000000..ac230ed8ac401473f3fd052f4bf65750ea05c248 --- /dev/null +++ b/examples/offline_inference/image_to_image/run_qwen_image_edit_2511.sh @@ -0,0 +1,8 @@ +python image_edit.py \ + --model Qwen/Qwen-Image-Edit-2511 \ + --image qwen_bear.png \ + --prompt "Add a white art board written with colorful text 'vLLM-Omni' on grassland. Add a paintbrush in the bear's hands. position the bear standing in front of the art board as if painting" \ + --output output_image_edit.png \ + --num_inference_steps 50 \ + --cfg_scale 4.0 \ + --cache_backend cache_dit \ diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a1355dab69e59a13d1d10ddae13b6c31e39a9b8f --- /dev/null +++ b/examples/offline_inference/image_to_video/README.md @@ -0,0 +1,62 @@ +# Image-To-Video + +This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models with vLLM-Omni's offline inference API. + +## Local CLI Usage + +### Wan2.2-I2V-A14B-Diffusers (MoE) +```bash +python image_to_video.py \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --image input.png \ + --prompt "A cat playing with yarn, smooth motion" \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 832 \ + --num_frames 48 \ + --guidance_scale 5.0 \ + --guidance_scale_high 6.0 \ + --num_inference_steps 40 \ + --boundary_ratio 0.875 \ + --flow_shift 12.0 \ + --fps 16 \ + --output i2v_output.mp4 +``` + +### Wan2.2-TI2V-5B-Diffusers (Unified) +```bash +python image_to_video.py \ + --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --image input.png \ + --prompt "A cat playing with yarn, smooth motion" \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 832 \ + --num_frames 48 \ + --guidance_scale 4.0 \ + --num_inference_steps 40 \ + --flow_shift 12.0 \ + --fps 16 \ + --output i2v_output.mp4 +``` + +Key arguments: + +- `--model`: Model ID (I2V-A14B for MoE, TI2V-5B for unified T2V+I2V). +- `--image`: Path to input image (required). +- `--prompt`: Text description of desired motion/animation. +- `--height/--width`: Output resolution (auto-calculated from image if not set). Dimensions should be multiples of 16. +- `--num_frames`: Number of frames (default 81). +- `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high-noise stages for MoE). +- `--negative_prompt`: Optional list of artifacts to suppress. 
+- `--boundary_ratio`: Boundary split ratio for two-stage MoE models. +- `--flow_shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--num_inference_steps`: Number of denoising steps (default 50). +- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--output`: Path to save the generated video. +- `--vae_use_slicing`: Enable VAE slicing for memory optimization. +- `--vae_use_tiling`: Enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--enable-cpu-offload`: enable CPU offloading for diffusion models. + +> ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8d3991559d3c5fb97494de3cbcdec972221167 --- /dev/null +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Image-to-Video generation example using Wan2.2 I2V/TI2V models. + +Supports: +- Wan2.2-I2V-A14B-Diffusers: MoE model with CLIP image encoder +- Wan2.2-TI2V-5B-Diffusers: Unified T2V+I2V model (dense 5B) + +Usage: + # I2V-A14B (MoE) + python image_to_video.py --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --image input.jpg --prompt "A cat playing with yarn" + + # TI2V-5B (unified) + python image_to_video.py --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --image input.jpg --prompt "A cat playing with yarn" +""" + +import argparse +import os +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video from an image with Wan2.2 I2V/TI2V.") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.2-I2V-A14B-Diffusers", + help="Diffusers Wan2.2 I2V model ID or local path.", + ) + parser.add_argument("--image", required=True, help="Path to input image.") + parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") + parser.add_argument("--negative_prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument("--guidance_scale", type=float, default=5.0, help="CFG scale.") + parser.add_argument( + "--guidance_scale_high", type=float, default=None, help="Optional separate CFG for high-noise (MoE only)." + ) + parser.add_argument( + "--height", type=int, default=None, help="Video height (auto-calculated from image if not set)." 
+ ) + parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames.") + parser.add_argument("--num_inference_steps", type=int, default=50, help="Sampling steps.") + parser.add_argument("--boundary_ratio", type=float, default=0.875, help="Boundary split ratio for MoE models.") + parser.add_argument( + "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." + ) + parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") + parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument( + "--vae_use_slicing", + action="store_true", + help="Enable VAE slicing for memory optimization.", + ) + parser.add_argument( + "--vae_use_tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.", + ) + parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offloading for diffusion models.", + ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) + return parser.parse_args() + + +def calculate_dimensions(image: PIL.Image.Image, max_area: int = 480 * 832) -> tuple[int, int]: + """Calculate output dimensions maintaining aspect ratio.""" + aspect_ratio = image.height / image.width + mod_value = 16 # Must be divisible by 16 + + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + return height, width + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + + # Load input image + image = PIL.Image.open(args.image).convert("RGB") + + # Calculate dimensions if not provided + height = args.height + width = args.width + if height is None or width is None: + # Default to 480P area for I2V + calc_height, calc_width = calculate_dimensions(image, max_area=480 * 832) + height = height or calc_height + width = width or calc_width + + # Resize image to target dimensions + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + parallel_config = DiffusionParallelConfig( + cfg_parallel_size=args.cfg_parallel_size, + ) + omni = Omni( + model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, + vae_use_slicing=args.vae_use_slicing, + vae_use_tiling=args.vae_use_tiling, + boundary_ratio=args.boundary_ratio, + flow_shift=args.flow_shift, + enable_cpu_offload=args.enable_cpu_offload, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # 
Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}") + print(f" Video size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + + # omni.generate() returns Generator[OmniRequestOutput, None, None] + frames = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_high, + num_inference_steps=args.num_inference_steps, + num_frames=args.num_frames, + ), + ) + + # Extract video frames from OmniRequestOutput + if isinstance(frames, list) and len(frames) > 0: + first_item = frames[0] + + # Check if it's an OmniRequestOutput + if hasattr(first_item, "final_output_type"): + if first_item.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation." + ) + + # Pipeline mode: extract from nested request_output + if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: + if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: + inner_output = first_item.request_output[0] + if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): + frames = inner_output.images[0] if inner_output.images else None + if frames is None: + raise ValueError("No video frames found in output.") + # Diffusion mode: use direct images field + elif hasattr(first_item, "images") and first_item.images: + frames = first_item.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + from diffusers.utils import export_to_video + except ImportError: + raise ImportError("diffusers is required for export_to_video.") + + # frames may be np.ndarray (preferred) or torch.Tensor + # export_to_video expects a list of frames with values in [0, 1] + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + # [B, C, F, H, W] or [B, F, H, W, C] + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + # If float, assume [-1,1] and normalize to [0,1] + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + else: + video_array = frames + if hasattr(video_array, "shape") and video_array.ndim == 5: + video_array = video_array[0] + + # Convert 4D array (frames, H, W, C) to list of frames for export_to_video + if isinstance(video_array, np.ndarray) and video_array.ndim == 4: + video_array = list(video_array) + + export_to_video(video_array, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + 
print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/lora_inference/README.md b/examples/offline_inference/lora_inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b0b195f6e6f85400487c3f652c7a1da46b587bca --- /dev/null +++ b/examples/offline_inference/lora_inference/README.md @@ -0,0 +1,98 @@ +# LoRA Inference Examples + +This directory contains examples for using LoRA (Low-Rank Adaptation) adapters with vLLM-omni diffusion models for offline inference. +The example uses the `stabilityai/stable-diffusion-3.5-medium` as the default model, but you can replace it with other models in vLLM-omni. + +## Overview + +Similar to vLLM, vLLM-omni uses a unified LoRA handling mechanism: + +- **Pre-loaded LoRA**: Loaded at initialization via `--lora-path` (pre-loaded into cache) +- **Per-request LoRA**: Loaded on-demand. In the example, the LoRA is loaded via `--lora-request-path` in each request + +Both approaches use the same underlying mechanism - all LoRA adapters are handled uniformly through `set_active_adapter()`. If no LoRA request is provided in a request, all adapters are deactivated. + +## Usage + +### Pre-loaded LoRA (via --lora-path) + +Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_preloaded.png +``` + +**Note**: When using `--lora-path`, the adapter is loaded at init time with a stable ID derived from the adapter path. This example activates it automatically for the request. + +### Per-request LoRA (via --lora-request-path) + +Load a LoRA adapter on-demand for each request: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-request-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_per_request.png +``` + +### No LoRA + +If no LoRA request is provided, we will use the base model without any LoRA adapters: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_no_lora.png +``` + +## Parameters + +### LoRA Parameters + +- `--lora-path`: Path to LoRA adapter folder to pre-load at initialization (loads into cache with a stable ID derived from the path) +- `--lora-request-path`: Path to LoRA adapter folder for per-request loading +- `--lora-request-id`: Integer ID for the LoRA adapter (optional). If not provided and `--lora-request-path` is set, will derive a stable ID from the path. +- `--lora-scale`: Scale factor for LoRA weights (default: 1.0). Higher values increase the influence of the LoRA adapter. 
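+
+For reference, a minimal Python sketch of the per-request path is shown below. It mirrors `lora_inference.py` in this directory; the adapter path, output filename, and image sizes are placeholders, and the bundled script extracts and saves outputs more defensively than this sketch does.
+
+```python
+from pathlib import Path
+
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.lora.request import LoRARequest
+from vllm_omni.lora.utils import stable_lora_int_id
+
+lora_path = "/path/to/lora/"  # placeholder: adapter folder in PEFT format
+omni = Omni(model="stabilityai/stable-diffusion-3.5-medium")
+
+# Per-request LoRA: derive a stable integer ID from the adapter path.
+lora_request = LoRARequest(
+    lora_name=Path(lora_path).stem,
+    lora_int_id=stable_lora_int_id(lora_path),
+    lora_path=lora_path,
+)
+
+sampling_params = OmniDiffusionSamplingParams(height=1024, width=1024, num_inference_steps=50)
+sampling_params.lora_request = lora_request
+sampling_params.lora_scale = 1.0
+
+outputs = omni.generate("A piece of cheesecake", sampling_params)
+outputs[0].images[0].save("output_per_request.png")  # placeholder output path
+```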
+ +### Standard Parameters + +- `--prompt`: Text prompt for image generation (required) +- `--seed`: Random seed for reproducibility (default: 42) +- `--height`: Image height in pixels (default: 1024) +- `--width`: Image width in pixels (default: 1024) +- `--num_inference_steps`: Number of denoising steps (default: 50) +- `--output`: Output file path (default: `lora_output.png`) + +## How LoRA Works + +All LoRA adapters are handled uniformly: + +1. **Initialization**: If `--lora-path` is provided, the adapter is loaded into cache with a stable ID derived from the adapter path +2. **Per-request**: If `--lora-request-path` is provided, the adapter is loaded/activated for that request +3. **No LoRA**: If no LoRA request is provided (`req.lora_request` is None), all adapters are deactivated + +The system uses LRU cache management - adapters are cached and evicted when the cache is full (unless pinned). + +## LoRA Adapter Format + +LoRA adapters must be in PEFT (Parameter-Efficient Fine-Tuning) format. A typical LoRA adapter directory structure: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` diff --git a/examples/offline_inference/lora_inference/lora_inference.py b/examples/offline_inference/lora_inference/lora_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4299edb84fdd8585c4cf1bc80037b45d93d547 --- /dev/null +++ b/examples/offline_inference/lora_inference/lora_inference.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +from pathlib import Path + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate images with LoRA adapters.") + parser.add_argument("--model", default="stabilityai/stable-diffusion-3.5-medium", help="Model name or path.") + parser.add_argument("--prompt", required=True, help="Text prompt for image generation.") + parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--output", + type=str, + default="lora_output.png", + help="Path to save the generated image (PNG).", + ) + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to LoRA adapter folder to pre-load at initialization (PEFT format). " + "Note: pre-loading populates the cache; you still need to pass a lora_request to activate it.", + ) + parser.add_argument( + "--lora-request-path", + type=str, + default=None, + help="Path to LoRA adapter folder for per-request activation (dynamic LoRA). " + "If --lora-request-id is not provided, a stable ID will be derived from this path.", + ) + parser.add_argument( + "--lora-request-id", + type=int, + default=None, + help="Integer ID for the LoRA adapter (for dynamic LoRA). 
" + "If not provided and --lora-request-path is set, will derive a stable ID from the path.", + ) + parser.add_argument( + "--lora-scale", + type=float, + default=1.0, + help="Scale factor for LoRA weights (default: 1.0).", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + model = args.model + + omni_kwargs = {} + + if args.lora_path: + omni_kwargs["lora_path"] = args.lora_path + print(f"Using static LoRA from: {args.lora_path}") + + omni = Omni(model=model, **omni_kwargs) + + lora_request = None + if args.lora_request_path: + if args.lora_request_id is None: + lora_request_id = stable_lora_int_id(args.lora_request_path) + else: + lora_request_id = args.lora_request_id + + lora_name = Path(args.lora_request_path).stem + lora_request = LoRARequest( + lora_name=lora_name, + lora_int_id=lora_request_id, + lora_path=args.lora_request_path, + ) + print(f"Using per-request LoRA: name={lora_name}, id={lora_request_id}, scale={args.lora_scale}") + elif args.lora_path: + # pre-loaded LoRA + lora_request_id = stable_lora_int_id(args.lora_path) + lora_request = LoRARequest( + lora_name="preloaded", + lora_int_id=lora_request_id, + lora_path=args.lora_path, + ) + print(f"Activating pre-loaded LoRA: id={lora_request_id}, scale={args.lora_scale}") + + sampling_params = OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + ) + + if lora_request: + sampling_params.lora_request = lora_request + sampling_params.lora_scale = args.lora_scale + + outputs = omni.generate(args.prompt, sampling_params) + + if not outputs or len(outputs) == 0: + raise ValueError("No output generated from omni.generate()") + + if isinstance(outputs, list): + first_output = outputs[0] + else: + first_output = outputs + + images = None + if hasattr(first_output, "images") and first_output.images: + images = first_output.images + elif hasattr(first_output, "request_output") and first_output.request_output: + req_out = first_output.request_output + if isinstance(req_out, list) and len(req_out) > 0: + req_out = req_out[0] + if hasattr(req_out, "images") and req_out.images: + images = req_out.images + + if not images: + raise ValueError("No images found in request_output") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "lora_output" + if len(images) <= 1: + images[0].save(output_path) + print(f"Saved generated image to {output_path}") + else: + for idx, img in enumerate(images): + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved generated image to {save_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md new file mode 100644 index 0000000000000000000000000000000000000000..20740a0da02034ebaa80b95de2a7c380120d2ca8 --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -0,0 +1,70 @@ +# Qwen2.5-Omni + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Multiple Prompts +Get into the example folder +```bash +cd examples/offline_inference/qwen2_5_omni +``` +Then run the command below. 
Note: for processing large volume data, it uses py_generator mode, which will return a python generator from Omni class. +```bash +bash run_multiple_prompts.sh +``` + +### Single Prompt +Get into the example folder +```bash +cd examples/offline_inference/qwen2_5_omni +``` +Then run the command below. +```bash +bash run_single_prompt.sh +``` + +### Modality control +If you want to control output modalities, e.g. only output text, you can run the command below: +```bash +python end2end.py --output-wav output_audio \ + --query-type mixed_modalities \ + --modalities text +``` + +#### Using Local Media Files +The `end2end.py` script supports local media files (audio, video, image) via CLI arguments: + +```bash +# Use single local media files +python end2end.py --query-type use_image --image-path /path/to/image.jpg +python end2end.py --query-type use_video --video-path /path/to/video.mp4 +python end2end.py --query-type use_audio --audio-path /path/to/audio.wav + +# Combine multiple local media files +python end2end.py --query-type mixed_modalities \ + --video-path /path/to/video.mp4 \ + --image-path /path/to/image.jpg \ + --audio-path /path/to/audio.wav + +# Use audio from video file +python end2end.py --query-type use_audio_in_video --video-path /path/to/video.mp4 + +``` + +If media file paths are not provided, the script will use default assets. Supported query types: +- `use_image`: Image input only +- `use_video`: Video input only +- `use_audio`: Audio input only +- `mixed_modalities`: Audio + image + video +- `use_audio_in_video`: Extract audio from video +- `text`: Text-only query + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..63b08202a8f9a84ce70c96bcee93c241b32d9f0a --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This example shows how to use vLLM-Omni for running offline inference +with the correct prompt format on Qwen2.5-Omni +""" + +import os +import time +from typing import NamedTuple + +import librosa +import numpy as np +import soundfile as sf +from PIL import Image +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset, video_to_ndarrays +from vllm.multimodal.image import convert_image_mode +from vllm.sampling_params import SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni.entrypoints.omni import Omni + +SEED = 42 + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. + +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." +) + + +def get_text_query(question: str = None) -> QueryResult: + if question is None: + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." 
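+    # The prompt is built in a ChatML-style layout: system/user/assistant turns
+    # delimited by <|im_start|>/<|im_end|>. The multimodal query builders below
+    # additionally insert placeholder spans such as <|audio_bos|><|AUDIO|><|audio_eos|>.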
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + return QueryResult( + inputs={ + "prompt": prompt, + }, + limit_mm_per_prompt={}, + ) + + +def get_mixed_modalities_query( + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, + num_frames: int = 16, + sampling_rate: int = 16000, +) -> QueryResult: + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + # Load video + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + # Load image + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + pil_image = Image.open(image_path) + image_data = convert_image_mode(pil_image, "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + # Load audio + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_data, + "image": image_data, + "video": video_frames, + }, + }, + limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1}, + ) + + +def get_use_audio_in_video_query( + video_path: str | None = None, num_frames: int = 16, sampling_rate: int = 16000 +) -> QueryResult: + question = "Describe the content of the video, then convert what the baby say into text." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|><|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + # Extract audio from video file + audio_signal, sr = librosa.load(video_path, sr=sampling_rate) + audio = (audio_signal.astype(np.float32), sr) + else: + asset = VideoAsset(name="baby_reading", num_frames=num_frames) + video_frames = asset.np_ndarrays + audio = asset.get_audio(sampling_rate=sampling_rate) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": video_frames, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={"audio": 1, "video": 1}, + ) + + +def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: + question = "Are these two audio clips the same?" 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + # Use the provided audio as the first audio, default as second + audio_list = [ + audio_data, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ] + else: + audio_list = [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ] + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_list, + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, + ) + + +def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult: + if question is None: + question = "What is the content of this image?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + pil_image = Image.open(image_path) + image_data = convert_image_mode(pil_image, "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "image": image_data, + }, + }, + limit_mm_per_prompt={"image": 1}, + ) + + +def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult: + if question is None: + question = "Why is this video funny?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": video_frames, + }, + }, + limit_mm_per_prompt={"video": 1}, + ) + + +def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: + if question is None: + question = "What is the content of this audio?" 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_data, + }, + }, + limit_mm_per_prompt={"audio": 1}, + ) + + +query_map = { + "use_mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "use_multi_audios": get_multi_audios_query, + "use_image": get_image_query, + "use_video": get_video_query, + "use_audio": get_audio_query, + "text": get_text_query, +} + + +def main(args): + model_name = "Qwen/Qwen2.5-Omni-7B" + + # Get paths from args + video_path = getattr(args, "video_path", None) + image_path = getattr(args, "image_path", None) + audio_path = getattr(args, "audio_path", None) + num_frames = getattr(args, "num_frames", 16) + sampling_rate = getattr(args, "sampling_rate", 16000) + + # Get the query function and call it with appropriate parameters + query_func = query_map[args.query_type] + if args.query_type == "mixed_modalities": + query_result = query_func( + video_path=video_path, + image_path=image_path, + audio_path=audio_path, + num_frames=num_frames, + sampling_rate=sampling_rate, + ) + elif args.query_type == "use_audio_in_video": + query_result = query_func(video_path=video_path, num_frames=num_frames, sampling_rate=sampling_rate) + elif args.query_type == "multi_audios": + query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate) + elif args.query_type == "use_image": + query_result = query_func(image_path=image_path) + elif args.query_type == "use_video": + query_result = query_func(video_path=video_path, num_frames=num_frames) + elif args.query_type == "use_audio": + query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate) + else: + query_result = query_func() + omni_llm = Omni( + model=model_name, + log_stats=args.enable_stats, + stage_init_timeout=args.stage_init_timeout, + batch_timeout=args.batch_timeout, + init_timeout=args.init_timeout, + shm_threshold_bytes=args.shm_threshold_bytes, + ) + thinker_sampling_params = SamplingParams( + temperature=0.0, # Deterministic - no randomness + top_p=1.0, # Disable nucleus sampling + top_k=-1, # Disable top-k sampling + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.1, + ) + talker_sampling_params = SamplingParams( + temperature=0.9, + top_p=0.8, + top_k=40, + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.05, + stop_token_ids=[8294], + ) + code2wav_sampling_params = SamplingParams( + temperature=0.0, # Deterministic - no randomness + top_p=1.0, # Disable nucleus sampling + top_k=-1, # Disable top-k sampling + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.1, + ) + + sampling_params_list = [ + thinker_sampling_params, + talker_sampling_params, + code2wav_sampling_params, + ] + + if args.txt_prompts is None: + prompts = [query_result.inputs for _ in range(args.num_prompts)] + else: + assert args.query_type == "text", "txt-prompts is only supported for text query 
type" + with open(args.txt_prompts, encoding="utf-8") as f: + lines = [ln.strip() for ln in f.readlines()] + prompts = [get_text_query(ln).inputs for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + + if args.modalities is not None: + output_modalities = args.modalities.split(",") + for i, prompt in enumerate(prompts): + prompt["modalities"] = output_modalities + + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + if profiler_enabled: + omni_llm.start_profile(stages=[0]) + omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) + + # Determine output directory: prefer --output-dir; fallback to --output-wav + output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav + os.makedirs(output_dir, exist_ok=True) + + total_requests = len(prompts) + processed_count = 0 + for stage_outputs in omni_generator: + if stage_outputs.final_output_type == "text": + for output in stage_outputs.request_output: + request_id = output.request_id + text_output = output.outputs[0].text + # Save aligned text file per request + prompt_text = output.prompt + out_txt = os.path.join(output_dir, f"{request_id}.txt") + lines = [] + lines.append("Prompt:\n") + lines.append(str(prompt_text) + "\n") + lines.append("vllm_text_output:\n") + lines.append(str(text_output).strip() + "\n") + try: + with open(out_txt, "w", encoding="utf-8") as f: + f.writelines(lines) + except Exception as e: + print(f"[Warn] Failed writing text file {out_txt}: {e}") + print(f"Request ID: {request_id}, Text saved to {out_txt}") + elif stage_outputs.final_output_type == "audio": + for output in stage_outputs.request_output: + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output["audio"] + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000) + print(f"Request ID: {request_id}, Saved audio to {output_wav}") + + processed_count += len(stage_outputs.request_output) + if profiler_enabled and processed_count >= total_requests: + print(f"[Info] Processed {processed_count}/{total_requests}. 
Stopping profiler inside active loop...") + # Stop the profiler while workers are still alive + omni_llm.stop_profile() + + print("[Info] Waiting 30s for workers to write massive trace files to disk...") + time.sleep(30) + print("[Info] Trace export wait finished.") + + omni_llm.close() + + +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="use_mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--enable-stats", + action="store_true", + default=False, + help="Enable writing detailed statistics (default: disabled)", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=300, + help="Timeout for initializing a single stage in seconds (default: 300)", + ) + parser.add_argument( + "--batch-timeout", + type=int, + default=5, + help="Timeout for batching in seconds (default: 5)", + ) + parser.add_argument( + "--init-timeout", + type=int, + default=300, + help="Timeout for initializing stages in seconds (default: 300)", + ) + parser.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="Threshold for using shared memory in bytes (default: 65536)", + ) + parser.add_argument( + "--output-wav", + default="output_audio", + help="[Deprecated] Output wav directory (use --output-dir).", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file. If not provided, uses default video asset.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file. If not provided, uses default image asset.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file. If not provided, uses default audio asset.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="Number of frames to extract from video (default: 16).", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="Sampling rate for audio loading (default: 16000).", + ) + parser.add_argument( + "--worker-backend", type=str, default="multi_process", choices=["multi_process", "ray"], help="backend" + ) + parser.add_argument( + "--ray-address", + type=str, + default=None, + help="Address of the Ray cluster.", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Modalities to use for the prompts.", + ) + parser.add_argument( + "--py-generator", + action="store_true", + default=False, + help="Use py_generator mode. 
The returned type of Omni.generate() is a Python Generator object.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen2_5_omni/extract_prompts.py b/examples/offline_inference/qwen2_5_omni/extract_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..dce0788dbf48c9a5a30c244eecaca1a36389c11d --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/extract_prompts.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +import argparse + + +def extract_prompt(line: str) -> str | None: + # Extract the content between the first '|' and the second '|' + i = line.find("|") + if i == -1: + return None + j = line.find("|", i + 1) + if j == -1: + return None + return line[i + 1 : j].strip() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--input", "-i", required=True, help="Input .lst file path") + parser.add_argument("--output", "-o", required=True, help="Output file path") + parser.add_argument( + "--topk", + "-k", + type=int, + default=100, + help="Extract the top K prompts (default: 100)", + ) + args = parser.parse_args() + + prompts = [] + with open(args.input, encoding="utf-8", errors="ignore") as f: + for line in f: + if len(prompts) >= args.topk: + break + p = extract_prompt(line.rstrip("\n")) + if p: + prompts.append(p) + + with open(args.output, "w", encoding="utf-8") as f: + for p in prompts: + f.write(p + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7c8edd38b93d8660f4aa89c4499783dd0518577 --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh @@ -0,0 +1,4 @@ +python end2end.py --output-wav output_audio \ + --query-type text \ + --txt-prompts ../qwen3_omni/text_prompts_10.txt \ + --py-generator diff --git a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh new file mode 100644 index 0000000000000000000000000000000000000000..c8e4cd2cbf3861fe1dfa2e3239773d2e0eb120e9 --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh @@ -0,0 +1,2 @@ +python end2end.py --output-wav output_audio \ + --query-type use_mixed_modalities diff --git a/examples/offline_inference/qwen3_omni/README.md b/examples/offline_inference/qwen3_omni/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b7eee8f74efdf307e14db2d195de6bcfed0c8ef1 --- /dev/null +++ b/examples/offline_inference/qwen3_omni/README.md @@ -0,0 +1,73 @@ +# Qwen3-Omni + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Multiple Prompts +Get into the example folder +```bash +cd examples/offline_inference/qwen3_omni +``` +Then run the command below. Note: for processing large volume data, it uses py_generator mode, which will return a python generator from Omni class. +```bash +bash run_multiple_prompts.sh +``` +### Single Prompt +Get into the example folder +```bash +cd examples/offline_inference/qwen3_omni +``` +Then run the command below. +```bash +bash run_single_prompt.sh +``` +If you have not enough memory, you can set thinker with tensor parallel. 
Just run the command below. +```bash +bash run_single_prompt_tp.sh +``` + +### Modality control +If you want to control output modalities, e.g. only output text, you can run the command below: +```bash +python end2end.py --output-wav output_audio \ + --query-type use_audio \ + --modalities text +``` + +#### Using Local Media Files +The `end2end.py` script supports local media files (audio, video, image) via command-line arguments: + +```bash +# Use local video file +python end2end.py --query-type use_video --video-path /path/to/video.mp4 + +# Use local image file +python end2end.py --query-type use_image --image-path /path/to/image.jpg + +# Use local audio file +python end2end.py --query-type use_audio --audio-path /path/to/audio.wav + +# Combine multiple local media files +python end2end.py --query-type mixed_modalities \ + --video-path /path/to/video.mp4 \ + --image-path /path/to/image.jpg \ + --audio-path /path/to/audio.wav +``` + +If media file paths are not provided, the script will use default assets. Supported query types: +- `use_video`: Video input +- `use_image`: Image input +- `use_audio`: Audio input +- `text`: Text-only query +- `multi_audios`: Multiple audio inputs +- `mixed_modalities`: Combination of video, image, and audio inputs + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..0b484c5077edf686975c5518f0447017fe11dcb8 --- /dev/null +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This example shows how to use vLLM for running offline inference +with the correct prompt format on Qwen3-Omni (thinker only). +""" + +import os +import time +from typing import NamedTuple + +import librosa +import numpy as np +import soundfile as sf +import vllm +from PIL import Image +from vllm import SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset, video_to_ndarrays +from vllm.multimodal.image import convert_image_mode +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni.entrypoints.omni import Omni + +SEED = 42 + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. + +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." +) + + +def get_text_query(question: str = None) -> QueryResult: + if question is None: + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + return QueryResult( + inputs={ + "prompt": prompt, + }, + limit_mm_per_prompt={}, + ) + + +def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult: + if question is None: + question = "Why is this video funny?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": video_frames, + }, + }, + limit_mm_per_prompt={"video": 1}, + ) + + +def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult: + if question is None: + question = "What is the content of this image?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + pil_image = Image.open(image_path) + image_data = convert_image_mode(pil_image, "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "image": image_data, + }, + }, + limit_mm_per_prompt={"image": 1}, + ) + + +def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: + if question is None: + question = "What is the content of this audio?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_data, + }, + }, + limit_mm_per_prompt={"audio": 1}, + ) + + +def get_mixed_modalities_query( + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, + num_frames: int = 16, + sampling_rate: int = 16000, +) -> QueryResult: + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>" + "<|vision_start|><|image_pad|><|vision_end|>" + "<|vision_start|><|video_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + # Load video + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + # Load image + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + pil_image = Image.open(image_path) + image_data = convert_image_mode(pil_image, "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + # Load audio + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": audio_data, + "image": image_data, + "video": video_frames, + }, + }, + limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1}, + ) + + +def get_multi_audios_query() -> QueryResult: + question = "Are these two audio clips the same?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>" + "<|audio_start|><|audio_pad|><|audio_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ], + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, + ) + + +# def get_use_audio_in_video_query(video_path: str | None = None) -> QueryResult: +# question = ( +# "Describe the content of the video in details, then convert what the " +# "baby say into text." +# ) +# prompt = ( +# f"<|im_start|>system\n{default_system}<|im_end|>\n" +# "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>" +# f"{question}<|im_end|>\n" +# f"<|im_start|>assistant\n" +# ) +# if video_path: +# if not os.path.exists(video_path): +# raise FileNotFoundError(f"Video file not found: {video_path}") +# video_frames = video_to_ndarrays(video_path, num_frames=16) +# else: +# video_frames = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays +# audio = extract_video_audio(video_path, sampling_rate=16000) +# return QueryResult( +# inputs={ +# "prompt": prompt, +# "multi_modal_data": { +# "video": video_frames, +# "audio": audio, +# }, +# "mm_processor_kwargs": { +# "use_audio_in_video": True, +# }, +# }, +# limit_mm_per_prompt={"audio": 1, "video": 1}, +# ) +def get_use_audio_in_video_query() -> QueryResult: + question = "Describe the content of the video in details, then convert what the baby say into text." 
+ prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + asset = VideoAsset(name="baby_reading", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={"audio": 1, "video": 1}, + ) + + +query_map = { + "text": get_text_query, + "use_audio": get_audio_query, + "use_image": get_image_query, + "use_video": get_video_query, + "use_multi_audios": get_multi_audios_query, + "use_mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, +} + + +def main(args): + model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + print("=" * 20, "\n", f"vllm version: {vllm.__version__}", "\n", "=" * 20) + + # Get paths from args + video_path = getattr(args, "video_path", None) + image_path = getattr(args, "image_path", None) + audio_path = getattr(args, "audio_path", None) + + # Get the query function and call it with appropriate parameters + query_func = query_map[args.query_type] + if args.query_type == "use_video": + query_result = query_func(video_path=video_path, num_frames=getattr(args, "num_frames", 16)) + elif args.query_type == "use_image": + query_result = query_func(image_path=image_path) + elif args.query_type == "use_audio": + query_result = query_func(audio_path=audio_path, sampling_rate=getattr(args, "sampling_rate", 16000)) + elif args.query_type == "mixed_modalities": + query_result = query_func( + video_path=video_path, + image_path=image_path, + audio_path=audio_path, + num_frames=getattr(args, "num_frames", 16), + sampling_rate=getattr(args, "sampling_rate", 16000), + ) + elif args.query_type == "multi_audios": + query_result = query_func() + elif args.query_type == "use_audio_in_video": + query_result = query_func() + else: + query_result = query_func() + + omni_llm = Omni( + model=model_name, + stage_configs_path=args.stage_configs_path, + log_stats=args.enable_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + thinker_sampling_params = SamplingParams( + temperature=0.9, + top_p=0.9, + top_k=-1, + max_tokens=1200, + repetition_penalty=1.05, + logit_bias={}, + seed=SEED, + ) + + talker_sampling_params = SamplingParams( + temperature=0.9, + top_k=50, + max_tokens=4096, + seed=SEED, + detokenize=False, + repetition_penalty=1.05, + stop_token_ids=[2150], # TALKER_CODEC_EOS_TOKEN_ID + ) + + # Sampling parameters for Code2Wav stage (audio generation) + code2wav_sampling_params = SamplingParams( + temperature=0.0, + top_p=1.0, + top_k=-1, + max_tokens=4096 * 16, + seed=SEED, + detokenize=True, + repetition_penalty=1.1, + ) + + sampling_params_list = [ + thinker_sampling_params, + talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni + code2wav_sampling_params, + ] + + if args.txt_prompts is None: + prompts = [query_result.inputs for _ in range(args.num_prompts)] + else: + assert args.query_type == "text", "txt-prompts is only supported for text query type" + with open(args.txt_prompts, encoding="utf-8") as f: + lines = [ln.strip() for ln in f.readlines()] + prompts = [get_text_query(ln).inputs for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + + if args.modalities is not None: + output_modalities = 
args.modalities.split(",") + for i, prompt in enumerate(prompts): + prompt["modalities"] = output_modalities + + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + if profiler_enabled: + omni_llm.start_profile(stages=[0]) + omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) + # Determine output directory: prefer --output-dir; fallback to --output-wav + output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav + os.makedirs(output_dir, exist_ok=True) + + total_requests = len(prompts) + processed_count = 0 + + print(f"query type: {args.query_type}") + + for stage_outputs in omni_generator: + if stage_outputs.final_output_type == "text": + for output in stage_outputs.request_output: + request_id = output.request_id + text_output = output.outputs[0].text + # Save aligned text file per request + prompt_text = output.prompt + out_txt = os.path.join(output_dir, f"{request_id}.txt") + lines = [] + lines.append("Prompt:\n") + lines.append(str(prompt_text) + "\n") + lines.append("vllm_text_output:\n") + lines.append(str(text_output).strip() + "\n") + try: + with open(out_txt, "w", encoding="utf-8") as f: + f.writelines(lines) + except Exception as e: + print(f"[Warn] Failed writing text file {out_txt}: {e}") + print(f"Request ID: {request_id}, Text saved to {out_txt}") + elif stage_outputs.final_output_type == "audio": + for output in stage_outputs.request_output: + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output["audio"] + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") + print(f"Request ID: {request_id}, Saved audio to {output_wav}") + + processed_count += len(stage_outputs.request_output) + if profiler_enabled and processed_count >= total_requests: + print(f"[Info] Processed {processed_count}/{total_requests}. 
Stopping profiler inside active loop...") + # Stop the profiler while workers are still alive + omni_llm.stop_profile() + + print("[Info] Waiting 30s for workers to write trace files to disk...") + time.sleep(30) + print("[Info] Trace export wait time finished.") + omni_llm.close() + + +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="use_mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--enable-stats", + action="store_true", + default=False, + help="Enable writing detailed statistics (default: disabled)", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=300, + help="Timeout for initializing a single stage in seconds (default: 300)", + ) + parser.add_argument( + "--batch-timeout", + type=int, + default=5, + help="Timeout for batching in seconds (default: 5)", + ) + parser.add_argument( + "--init-timeout", + type=int, + default=300, + help="Timeout for initializing stages in seconds (default: 300)", + ) + parser.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="Threshold for using shared memory in bytes (default: 65536)", + ) + parser.add_argument( + "--output-wav", + default="output_audio", + help="[Deprecated] Output wav directory (use --output-dir).", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=None, + help="Path to a stage configs file.", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file. If not provided, uses default video asset.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file. If not provided, uses default image asset.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file. If not provided, uses default audio asset.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="Number of frames to extract from video (default: 16).", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="Sampling rate for audio loading (default: 16000).", + ) + parser.add_argument( + "--log-dir", + type=str, + default="logs", + help="Log directory (default: logs).", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Output modalities to use for the prompts.", + ) + parser.add_argument( + "--py-generator", + action="store_true", + default=False, + help="Use py_generator mode. 
The returned type of Omni.generate() is a Python Generator object.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen3_omni/run_multiple_prompts.sh b/examples/offline_inference/qwen3_omni/run_multiple_prompts.sh new file mode 100644 index 0000000000000000000000000000000000000000..a48068af938e756e415edb23ed6d4d529110f63c --- /dev/null +++ b/examples/offline_inference/qwen3_omni/run_multiple_prompts.sh @@ -0,0 +1,4 @@ +python end2end.py --output-wav output_audio \ + --query-type text \ + --txt-prompts text_prompts_10.txt \ + --py-generator diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt.sh b/examples/offline_inference/qwen3_omni/run_single_prompt.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6ca09da6e23c4eee7be756873eed22b1525c61a --- /dev/null +++ b/examples/offline_inference/qwen3_omni/run_single_prompt.sh @@ -0,0 +1,2 @@ +python end2end.py --output-wav output_audio \ + --query-type use_audio diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh new file mode 100644 index 0000000000000000000000000000000000000000..0cb459eab77e636dc3774d63a5818bee9b17b149 --- /dev/null +++ b/examples/offline_inference/qwen3_omni/run_single_prompt_tp.sh @@ -0,0 +1,5 @@ +python end2end.py --output-wav output_audio \ + --query-type use_audio \ + --stage-init-timeout 300 + +# stage-init-timeout sets the maximum wait to avoid two vLLM stages initializing at the same time on the same card. diff --git a/examples/offline_inference/qwen3_omni/text_prompts_10.txt b/examples/offline_inference/qwen3_omni/text_prompts_10.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e5fbe0e3dba0c5fe17f5b125564122269e998cf --- /dev/null +++ b/examples/offline_inference/qwen3_omni/text_prompts_10.txt @@ -0,0 +1,10 @@ +What is the capital of France? +How many planets are in our solar system? +What is the largest ocean on Earth? +Who wrote the novel "1984"? +What is the chemical symbol for water? +What year did World War II end? +What is the tallest mountain in the world? +What is the speed of light in vacuum? +Who painted the Mona Lisa? +What is the smallest prime number? diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eab66fe14ea019d9110ce49db9e31d2403041d0e --- /dev/null +++ b/examples/offline_inference/qwen3_tts/README.md @@ -0,0 +1,93 @@ +# Qwen3-TTS Offline Inference + +This directory contains an offline demo for running Qwen3 TTS models with vLLM Omni. It builds task-specific inputs and generates WAV files locally. + +## Model Overview + +Qwen3 TTS provides multiple task variants for speech generation: + +- **CustomVoice**: Generate speech with a known speaker identity (speaker ID) and optional instruction. +- **VoiceDesign**: Generate speech from text plus a descriptive instruction that designs a new voice. +- **Base**: Voice cloning using a reference audio + reference transcript, with optional mode selection. + +## Setup +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +### ROCm Dependencies + +You will need to install these two dependencies `onnxruntime-rocm` and `sox`. 
+ +``` +pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm +pip install onnxruntime-rocm sox +``` + +## Quick Start + +Run a single sample for a task: + +``` +python end2end.py --query-type CustomVoice +``` + +Generated audio files are saved to `output_audio/` by default. + +## Task Usage + +### CustomVoice + +Single sample: + +``` +python end2end.py --query-type CustomVoice +``` + +Batch sample (multiple prompts in one run): + +``` +python end2end.py --query-type CustomVoice --use-batch-sample +``` + +### VoiceDesign + +Single sample: + +``` +python end2end.py --query-type VoiceDesign +``` + +Batch sample: + +``` +python end2end.py --query-type VoiceDesign --use-batch-sample +``` + +### Base (Voice Clone) + +Single sample: + +``` +python end2end.py --query-type Base +``` + +Batch sample: + +``` +python end2end.py --query-type Base --use-batch-sample +``` + +Mode selection for Base: + +- `--mode-tag icl` (default): standard mode +- `--mode-tag xvec_only`: enable `x_vector_only_mode` in the request + +Examples: + +``` +python end2end.py --query-type Base --mode-tag icl +``` + +## Notes + +- The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs. +- Use `--output-dir` (preferred) or `--output-wav` to change the output folder. diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..18bb9dedcb1347a657b32b2060570282d9db13da --- /dev/null +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -0,0 +1,381 @@ +"""Offline inference demo for Qwen3 TTS via vLLM Omni. + +Provides single and batch sample inputs for CustomVoice, VoiceDesign, and Base +tasks, then runs Omni generation and saves output wav files. +""" + +import os +from typing import NamedTuple + +import soundfile as sf + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +from vllm import SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import Omni + + +class QueryResult(NamedTuple): + """Container for a prepared Omni request.""" + + inputs: dict + model_name: str + + +def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: + """Build CustomVoice sample inputs. + + Args: + use_batch_sample: When True, return a batch of prompts; otherwise a single prompt. + + Returns: + QueryResult with Omni inputs and the CustomVoice model path. 
+ """ + task_type = "CustomVoice" + if use_batch_sample: + texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."] + instructs = ["", "Very happy."] + languages = ["Chinese", "English"] + speakers = ["Vivian", "Ryan"] + inputs = [] + for text, instruct, language, speaker in zip(texts, instructs, languages, speakers): + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs.append( + { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "instruct": [instruct], + "language": [language], + "speaker": [speaker], + "max_new_tokens": [2048], + }, + } + ) + else: + text = "其实我真的有发现,我是一个特别善于观察别人情绪的人。" + language = "Chinese" + speaker = "Vivian" + instruct = "用特别愤怒的语气说" + prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs = { + "prompt": prompts, + "additional_information": { + "task_type": [task_type], + "text": [text], + "language": [language], + "speaker": [speaker], + "instruct": [instruct], + "max_new_tokens": [2048], + }, + } + return QueryResult( + inputs=inputs, + model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + ) + + +def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: + """Build VoiceDesign sample inputs. + + Args: + use_batch_sample: When True, return a batch of prompts; otherwise a single prompt. + + Returns: + QueryResult with Omni inputs and the VoiceDesign model path. + """ + task_type = "VoiceDesign" + if use_batch_sample: + texts = [ + "哥哥,你回来啦,人家等了你好久好久了,要抱抱!", + "It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!", + ] + instructs = [ + "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。", + "Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice.", + ] + languages = ["Chinese", "English"] + inputs = [] + for text, instruct, language in zip(texts, instructs, languages): + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs.append( + { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + }, + } + ) + else: + text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!" + instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。" + language = "Chinese" + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs = { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + }, + } + return QueryResult( + inputs=inputs, + model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", + ) + + +def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> QueryResult: + """Build Base (voice clone) sample inputs. + + Args: + use_batch_sample: When True, return a batch of prompts (Case 2). + mode_tag: "icl" or "xvec_only" to control x_vector_only_mode behavior. + + Returns: + QueryResult with Omni inputs and the Base model path. + """ + task_type = "Base" + ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" + ref_audio_single = ref_audio_path_1 + ref_text_single = ( + "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you." + ) + syn_text_single = "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye." 
+ syn_lang_single = "Auto" + x_vector_only_mode = mode_tag == "xvec_only" + if use_batch_sample: + syn_text_batch = [ + "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye.", + "其实我真的有发现,我是一个特别善于观察别人情绪的人。", + ] + syn_lang_batch = ["Chinese", "English"] + inputs = [] + for text, language in zip(syn_text_batch, syn_lang_batch): + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs.append( + { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [text], + "language": [language], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + }, + } + ) + else: + prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n" + inputs = { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [syn_text_single], + "language": [syn_lang_single], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + }, + } + return QueryResult( + inputs=inputs, + model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base", + ) + + +def main(args): + """Run offline inference with Omni using prepared sample inputs. + + Args: + args: Parsed CLI args from parse_args(). + """ + query_func = query_map[args.query_type] + if args.query_type in {"CustomVoice", "VoiceDesign"}: + query_result = query_func(use_batch_sample=args.use_batch_sample) + elif args.query_type == "Base": + query_result = query_func( + use_batch_sample=args.use_batch_sample, + mode_tag=args.mode_tag, + ) + else: + query_result = query_func() + + model_name = query_result.model_name + omni = Omni( + model=model_name, + stage_configs_path=args.stage_configs_path, + log_stats=args.enable_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + sampling_params = SamplingParams( + temperature=0.9, + top_p=1.0, + top_k=50, + max_tokens=2048, + seed=42, + detokenize=False, + repetition_penalty=1.05, + ) + + sampling_params_list = [ + sampling_params, + ] + + output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav + os.makedirs(output_dir, exist_ok=True) + + omni_generator = omni.generate(query_result.inputs, sampling_params_list) + for stage_outputs in omni_generator: + for output in stage_outputs.request_output: + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output["audio"] + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + audio_samplerate = output.outputs[0].multimodal_output["sr"].item() + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") + print(f"Request ID: {request_id}, Saved audio to {output_wav}") + + +def parse_args(): + """Parse CLI arguments for offline TTS inference. + + Returns: + argparse.Namespace with CLI options. 
+ """ + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="CustomVoice", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--enable-stats", + action="store_true", + default=False, + help="Enable writing detailed statistics (default: disabled)", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=300, + help="Timeout for initializing a single stage in seconds (default: 300)", + ) + parser.add_argument( + "--batch-timeout", + type=int, + default=5, + help="Timeout for batching in seconds (default: 5)", + ) + parser.add_argument( + "--init-timeout", + type=int, + default=300, + help="Timeout for initializing stages in seconds (default: 300)", + ) + parser.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="Threshold for using shared memory in bytes (default: 65536)", + ) + parser.add_argument( + "--output-wav", + default="output_audio", + help="[Deprecated] Output wav directory (use --output-dir).", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=None, + help="Path to a stage configs file.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file. If not provided, uses default audio asset.", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="Sampling rate for audio loading (default: 16000).", + ) + parser.add_argument( + "--log-dir", + type=str, + default="logs", + help="Log directory (default: logs).", + ) + parser.add_argument( + "--py-generator", + action="store_true", + default=False, + help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.", + ) + parser.add_argument( + "--use-batch-sample", + action="store_true", + default=False, + help="Use batch input sample for CustomVoice/VoiceDesign/Base query.", + ) + parser.add_argument( + "--mode-tag", + type=str, + default="icl", + choices=["icl", "xvec_only"], + help="Mode tag for Base query x_vector_only_mode (default: icl).", + ) + + return parser.parse_args() + + +query_map = { + "CustomVoice": get_custom_voice_query, + "VoiceDesign": get_voice_design_query, + "Base": get_base_query, +} + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/text_to_audio/README.md b/examples/offline_inference/text_to_audio/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ec1eafe52d321879b69db1e3e2c0650942c9ad8 --- /dev/null +++ b/examples/offline_inference/text_to_audio/README.md @@ -0,0 +1,37 @@ +# Text-To-Audio + +The `stabilityai/stable-audio-open-1.0` pipeline generates audio from text prompts. + +## Prerequisites + +If you use a gated model (e.g., `stabilityai/stable-audio-open-1.0`), ensure you have access: + +1. **Accept Model License**: Visit the model page on Hugging Face (e.g., [stabilityai/stable-audio-open-1.0]) and accept the user agreement. +2. **Authenticate**: Log in to Hugging Face locally to access the gated model. 
+ ```bash + huggingface-cli login + ``` + +## Local CLI Usage + +```bash +python text_to_audio.py \ + --model stabilityai/stable-audio-open-1.0 \ + --prompt "The sound of a hammer hitting a wooden surface" \ + --negative_prompt "Low quality" \ + --seed 42 \ + --guidance_scale 7.0 \ + --audio_length 10.0 \ + --num_inference_steps 100 \ + --output stable_audio_output.wav +``` + +Key arguments: + +- `--prompt`: text description (string). +- `--negative_prompt`: negative prompt for classifier-free guidance. +- `--seed`: integer seed for deterministic generation. +- `--guidance_scale`: classifier-free guidance scale. +- `--audio_length`: audio duration in seconds. +- `--num_inference_steps`: diffusion sampling steps.(more steps = higher quality, slower). +- `--output`: path to save the generated WAV file. diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9efcca5ff8a4eeacba7b0aa6b8fd166e9710f8 --- /dev/null +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script for text-to-audio generation using Stable Audio Open. + +This script demonstrates how to generate audio from text prompts using +the Stable Audio Open model with vLLM-Omni. + +Usage: + python text_to_audio.py --prompt "The sound of a dog barking" + python text_to_audio.py --prompt "A piano playing a gentle melody" --audio_length 10.0 + python text_to_audio.py --prompt "Thunder and rain sounds" --negative_prompt "Low quality" +""" + +import argparse +import time +from pathlib import Path + +import numpy as np +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate audio with Stable Audio Open.") + parser.add_argument( + "--model", + default="stabilityai/stable-audio-open-1.0", + help="Stable Audio model name or local path.", + ) + parser.add_argument( + "--prompt", + default="The sound of a hammer hitting a wooden surface.", + help="Text prompt for audio generation.", + ) + parser.add_argument( + "--negative_prompt", + default="Low quality.", + help="Negative prompt for classifier-free guidance.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for deterministic results.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7.0, + help="Classifier-free guidance scale.", + ) + parser.add_argument( + "--audio_start", + type=float, + default=0.0, + help="Audio start time in seconds.", + ) + parser.add_argument( + "--audio_length", + type=float, + default=10.0, + help="Audio length in seconds (max ~47s for stable-audio-open-1.0).", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--num_waveforms", + type=int, + default=1, + help="Number of audio waveforms to generate for the given prompt.", + ) + parser.add_argument( + "--output", + type=str, + default="stable_audio_output.wav", + help="Path to save the generated audio (WAV format).", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=44100, + help="Sample rate for output 
audio (Stable Audio uses 44100 Hz).", + ) + return parser.parse_args() + + +def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 44100): + """Save audio data to a WAV file.""" + try: + import soundfile as sf + + sf.write(output_path, audio_data, sample_rate) + except ImportError: + try: + import scipy.io.wavfile as wav + + # Ensure audio is in the correct format for scipy + if audio_data.dtype == np.float32 or audio_data.dtype == np.float64: + # Normalize to int16 range + audio_data = np.clip(audio_data, -1.0, 1.0) + audio_data = (audio_data * 32767).astype(np.int16) + wav.write(output_path, sample_rate, audio_data) + except ImportError: + raise ImportError( + "Either 'soundfile' or 'scipy' is required to save audio files. " + "Install with: pip install soundfile or pip install scipy" + ) + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + + print(f"\n{'=' * 60}") + print("Stable Audio Open - Text-to-Audio Generation") + print(f"{'=' * 60}") + print(f" Model: {args.model}") + print(f" Prompt: {args.prompt}") + print(f" Negative prompt: {args.negative_prompt}") + print(f" Audio length: {args.audio_length}s") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Guidance scale: {args.guidance_scale}") + print(f" Seed: {args.seed}") + print(f"{'=' * 60}\n") + + # Initialize Omni with Stable Audio model + omni = Omni(model=args.model) + + # Calculate audio end time + audio_end_in_s = args.audio_start + args.audio_length + + # Time profiling for generation + generation_start = time.perf_counter() + + # Generate audio + outputs = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + }, + OmniDiffusionSamplingParams( + generator=generator, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_waveforms, + extra_args={ + "audio_start_in_s": args.audio_start, + "audio_end_in_s": audio_end_in_s, + }, + ), + ) + + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + print(f"Total generation time: {generation_time:.2f} seconds") + + # Process and save audio + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".wav" + stem = output_path.stem or "stable_audio_output" + + # Extract audio from omni.generate() outputs + if not outputs: + raise ValueError("No output generated from omni.generate()") + + output = outputs[0] + if not hasattr(output, "request_output") or not output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + request_output = output.request_output[0] + if not hasattr(request_output, "multimodal_output"): + raise ValueError("No multimodal_output found in request_output") + + audio = request_output.multimodal_output.get("audio") + if audio is None: + raise ValueError("No audio output found in request_output") + + # Handle different output formats + if isinstance(audio, torch.Tensor): + audio = audio.cpu().float().numpy() + + # Audio shape is typically [batch, channels, samples] or [channels, samples] + if audio.ndim == 3: + # [batch, channels, samples] + if args.num_waveforms <= 1: + audio_data = audio[0].T # [samples, channels] + save_audio(audio_data, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + else: + for idx in range(audio.shape[0]): + audio_data = audio[idx].T # [samples, 
channels] + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + save_audio(audio_data, str(save_path), args.sample_rate) + print(f"Saved generated audio to {save_path}") + elif audio.ndim == 2: + # [channels, samples] + audio_data = audio.T # [samples, channels] + save_audio(audio_data, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + else: + # [samples] - mono audio + save_audio(audio, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + + print(f"\nGenerated {args.audio_length}s of audio at {args.sample_rate} Hz") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9c57a621cff4d107e4d84751ec48791f3cc555da --- /dev/null +++ b/examples/offline_inference/text_to_image/README.md @@ -0,0 +1,116 @@ +# Text-To-Image + +This folder provides several entrypoints for experimenting with `Qwen/Qwen-Image` `Qwen/Qwen-Image-2512` `Tongyi-MAI/Z-Image-Turbo` using vLLM-Omni: + +- `text_to_image.py`: command-line script for single image generation with advanced options. +- `web_demo.py`: lightweight Gradio UI for interactive prompt/seed/CFG exploration. + +Note that when you pass in multiple independent prompts, they will be processed sequentially. Batching requests is currently not supported. + +## Basic Usage + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + prompt = "a cup of coffee on the table" + outputs = omni.generate(prompt) + images = outputs[0].request_output[0].images + images[0].save("coffee.png") +``` + +Or put more than one prompt in a request. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + prompts = [ + "a cup of coffee on a table", + "a toy dinosaur on a sandy beach", + "a fox waking up in bed and yawning", + ] + outputs = omni.generate(prompts) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + +!!! info + + However, it is not currently recommended to do so + because not all models support batch inference, + and batch requesting mostly does not provide significant performance improvement (despite the impression that it does). + This feature is primarily for the sake of interface compatibility with vLLM and to allow for future improvements. + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../../../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + +Apart from string prompt, vLLM-Omni also supports dictionary prompts in the same style as vLLM. +This is useful for models that support negative prompts. 
+ +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + outputs = omni.generate([ + { + "prompt": "a cup of coffee on a table", + "negative_prompt": "low resolution" + }, + { + "prompt": "a toy dinosaur on a sandy beach", + "negative_prompt": "cinematic, realistic" + } + ]) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + +## Local CLI Usage + +```bash +python text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --prompt "a cup of coffee on the table" \ + --seed 42 \ + --cfg_scale 4.0 \ + --num_images_per_prompt 1 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output outputs/coffee.png +``` + +Key arguments: + +- `--prompt`: text description (string). +- `--seed`: integer seed for deterministic sampling. +- `--cfg_scale`: true CFG scale (model-specific guidance strength). +- `--num_images_per_prompt`: number of images to generate per prompt (saves as `output`, `output_1`, ...). +- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). +- `--height/--width`: output resolution (defaults 1024x1024). +- `--output`: path to save the generated PNG. +- `--vae_use_slicing`: enable VAE slicing for memory optimization. +- `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--enable-cpu-offload`: enable CPU offloading for diffusion models. + +> ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. + +> ℹ️ Qwen-Image currently publishes best-effort presets at `1328x1328`, `1664x928`, `928x1664`, `1472x1140`, `1140x1472`, `1584x1056`, and `1056x1584`. Adjust `--height/--width` accordingly for the most reliable outcomes. + +## Web UI Demo + +Launch the gradio demo: + +```bash +python gradio_demo.py --port 7862 +``` + +Then open `http://localhost:7862/` on your local browser to interact with the web UI. 
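+## Python API with Sampling Parameters
+
+The CLI flags above map onto `OmniDiffusionSamplingParams` when calling the Python API directly. The snippet below is a minimal sketch based on the parameters used by `text_to_image.py` in this folder (seed via a `torch.Generator`, steps, CFG scale, and resolution); adjust the values to your model's supported presets.
+
+```python
+import torch
+
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.platforms import current_omni_platform
+
+if __name__ == "__main__":
+    omni = Omni(model="Qwen/Qwen-Image")
+    # Seed the generator for reproducible sampling.
+    generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(42)
+    outputs = omni.generate(
+        "a cup of coffee on the table",
+        OmniDiffusionSamplingParams(
+            height=1024,
+            width=1024,
+            generator=generator,
+            true_cfg_scale=4.0,  # corresponds to --cfg_scale in text_to_image.py
+            num_inference_steps=50,
+        ),
+    )
+    outputs[0].request_output[0].images[0].save("coffee.png")
+```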
diff --git a/examples/offline_inference/text_to_image/gradio_demo.py b/examples/offline_inference/text_to_image/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..614e40903027004a162001a32d52e2a83ac328ec --- /dev/null +++ b/examples/offline_inference/text_to_image/gradio_demo.py @@ -0,0 +1,237 @@ +import argparse +from functools import lru_cache + +import gradio as gr +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +ASPECT_RATIOS: dict[str, tuple[int, int]] = { + "1:1": (1328, 1328), + "16:9": (1664, 928), + "9:16": (928, 1664), + "4:3": (1472, 1140), + "3:4": (1140, 1472), + "3:2": (1584, 1056), + "2:3": (1056, 1584), +} +ASPECT_RATIO_CHOICES = [f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items()] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Gradio demo for Qwen-Image offline inference.") + parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.") + parser.add_argument( + "--height", + type=int, + default=1328, + help="Default image height (must match one of the supported presets).", + ) + parser.add_argument( + "--width", + type=int, + default=1328, + help="Default image width (must match one of the supported presets).", + ) + parser.add_argument("--default-prompt", default="a cup of coffee on the table", help="Initial prompt shown in UI.") + parser.add_argument("--default-seed", type=int, default=42, help="Initial seed shown in UI.") + parser.add_argument("--default-cfg-scale", type=float, default=4.0, help="Initial CFG scale shown in UI.") + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Default number of denoising steps shown in the UI.", + ) + parser.add_argument("--ip", default="127.0.0.1", help="Host/IP for Gradio `launch`.") + parser.add_argument("--port", type=int, default=7862, help="Port for Gradio `launch`.") + parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.") + args = parser.parse_args() + args.aspect_ratio_label = next( + (ratio for ratio, dims in ASPECT_RATIOS.items() if dims == (args.width, args.height)), + None, + ) + if args.aspect_ratio_label is None: + supported = ", ".join(f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items()) + parser.error(f"Unsupported resolution {args.width}x{args.height}. 
Please pick one of: {supported}.") + return args + + +@lru_cache(maxsize=1) +def get_omni(model_name: str) -> Omni: + # Enable VAE memory optimizations on NPU + vae_use_slicing = current_omni_platform.is_npu() + vae_use_tiling = current_omni_platform.is_npu() + return Omni( + model=model_name, + vae_use_slicing=vae_use_slicing, + vae_use_tiling=vae_use_tiling, + ) + + +def build_demo(args: argparse.Namespace) -> gr.Blocks: + omni = get_omni(args.model) + + def run_inference( + prompt: str, + seed_value: float, + cfg_scale_value: float, + resolution_choice: str, + num_steps_value: float, + num_images_choice: float, + ): + if not prompt or not prompt.strip(): + raise gr.Error("Please enter a non-empty prompt.") + ratio_label = resolution_choice.split(" ", 1)[0] + if ratio_label not in ASPECT_RATIOS: + raise gr.Error(f"Unsupported aspect ratio: {ratio_label}") + width, height = ASPECT_RATIOS[ratio_label] + try: + seed = int(seed_value) + num_steps = int(num_steps_value) + num_images = int(num_images_choice) + except (TypeError, ValueError) as exc: + raise gr.Error("Seed, inference steps, and number of images must be valid integers.") from exc + if num_steps <= 0: + raise gr.Error("Inference steps must be a positive integer.") + if num_images not in {1, 2, 3, 4}: + raise gr.Error("Number of images must be 1, 2, 3, or 4.") + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) + outputs = omni.generate( + prompt.strip(), + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + true_cfg_scale=float(cfg_scale_value), + num_inference_steps=num_steps, + num_outputs_per_prompt=num_images, + ), + ) + images_outputs = [] + for output in outputs: + req_out = output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + images = req_out.images + if not images: + raise ValueError("No images found in request_output") + # Extend the list with individual images (not append the entire list) + images_outputs.extend(images) + if len(images_outputs) >= num_images: + break + # Return only the requested number of images + return images_outputs[:num_images] + + with gr.Blocks( + title="vLLM-Omni Web Serving Demo", + css=""" + /* Left column button width */ + .left-column button { + width: 100%; + } + /* Right preview area: fixed height, hide unnecessary buttons */ + .fixed-image { + height: 660px; + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + } + .fixed-image .duplicate-button, + .fixed-image .svelte-drgfj2 { + display: none !important; + } + /* Gallery container: fill available space and center content */ + #image-gallery { + width: 100%; + height: 100%; + display: flex; + align-items: center; + justify-content: center; + } + /* Gallery grid: center horizontally and vertically, set gap */ + #image-gallery .grid { + display: flex; + flex-wrap: wrap; + justify-content: center; + align-items: center; + align-content: center; + gap: 16px; + width: 100%; + height: 100%; + } + /* Gallery grid items: center content */ + #image-gallery .grid > div { + display: flex; + align-items: center; + justify-content: center; + } + /* Gallery images: limit max height, maintain aspect ratio */ + .fixed-image img { + max-height: 660px !important; + width: auto !important; + object-fit: contain; + } + """, + ) as demo: + gr.Markdown("# vLLM-Omni Web Serving Demo") + gr.Markdown(f"**Model:** 
{args.model}") + + with gr.Row(): + with gr.Column(scale=1, elem_classes="left-column"): + prompt_input = gr.Textbox( + label="Prompt", + value=args.default_prompt, + placeholder="Describe the image you want to generate...", + lines=5, + ) + seed_input = gr.Number(label="Seed", value=args.default_seed, precision=0) + cfg_input = gr.Number(label="CFG Scale", value=args.default_cfg_scale) + steps_input = gr.Number( + label="Inference Steps", + value=args.num_inference_steps, + precision=0, + minimum=1, + ) + aspect_dropdown = gr.Dropdown( + label="Aspect Ratio (W:H)", + choices=ASPECT_RATIO_CHOICES, + value=f"{args.aspect_ratio_label} ({ASPECT_RATIOS[args.aspect_ratio_label][0]}x{ASPECT_RATIOS[args.aspect_ratio_label][1]})", + ) + num_images = gr.Dropdown( + label="Number of images", + choices=["1", "2", "3", "4"], + value="1", + ) + generate_btn = gr.Button("Generate", variant="primary") + with gr.Column(scale=2, elem_classes="fixed-image"): + gallery = gr.Gallery( + label="Preview", + columns=2, + rows=2, + height=660, + allow_preview=True, + show_label=True, + elem_id="image-gallery", + ) + + generate_btn.click( + fn=run_inference, + inputs=[prompt_input, seed_input, cfg_input, aspect_dropdown, steps_input, num_images], + outputs=gallery, + ) + + return demo + + +def main(): + args = parse_args() + demo = build_demo(args) + demo.launch(server_name=args.ip, server_port=args.port, share=args.share) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..a79e5d640d261b2cb04fcfd85994647e33fc8663 --- /dev/null +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import os +import time +from pathlib import Path + +import torch + +from vllm_omni.diffusion.data import DiffusionParallelConfig, logger +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") + parser.add_argument( + "--model", + default="Qwen/Qwen-Image", + help="Diffusion model name or local path. 
Supported models: " + "Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo, Qwen/Qwen-Image-2512", + ) + parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") + parser.add_argument( + "--negative_prompt", + default=None, + help="negative prompt for classifier-free conditional guidance.", + ) + parser.add_argument("--seed", type=int, default=142, help="Random seed for deterministic results.") + parser.add_argument( + "--cfg_scale", + type=float, + default=4.0, + help="True classifier-free guidance scale specific to Qwen-Image.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Classifier-free guidance scale.", + ) + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") + parser.add_argument( + "--output", + type=str, + default="qwen_image_output.png", + help="Path to save the generated image (PNG).", + ) + parser.add_argument( + "--num_images_per_prompt", + type=int, + default=1, + help="Number of images to generate for the given prompt.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit", "tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " + "Default: None (no cache acceleration)." + ), + ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) + parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offloading for diffusion models.", + ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) + parser.add_argument( + "--vae_use_slicing", + action="store_true", + help="Enable VAE slicing for memory optimization.", + ) + parser.add_argument( + "--vae_use_tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + + # Configure cache based on backend type + cache_config = None + if args.cache_backend == "cache_dit": + # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer + # All parameters 
marked with [cache-dit only] in DiffusionCacheConfig + cache_config = { + # DBCache parameters [cache-dit only] + "Fn_compute_blocks": 1, # Optimized for single-transformer models + "Bn_compute_blocks": 0, # Number of backward compute blocks + "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching + "max_continuous_cached_steps": 3, # Limit to prevent precision degradation + # TaylorSeer parameters [cache-dit only] + "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) + "taylorseer_order": 1, # TaylorSeer polynomial order + # SCM (Step Computation Masking) parameters [cache-dit only] + "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" + } + elif args.cache_backend == "tea_cache": + # TeaCache configuration + # All parameters marked with [tea_cache only] in DiffusionCacheConfig + cache_config = { + # TeaCache parameters [tea_cache only] + "rel_l1_thresh": 0.2, # Threshold for accumulated relative L1 distance + # Note: coefficients will use model-specific defaults based on model_type + # (e.g., QwenImagePipeline or FluxPipeline) + } + + # assert args.ring_degree == 1, "Ring attention is not supported yet" + parallel_config = DiffusionParallelConfig( + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + ) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + omni = Omni( + model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, + vae_use_slicing=args.vae_use_slicing, + vae_use_tiling=args.vae_use_tiling, + cache_backend=args.cache_backend, + cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, + enable_cpu_offload=args.enable_cpu_offload, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Time profiling for generation + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print( + f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, " + f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + ) + print(f" Image size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + + generation_start = time.perf_counter() + outputs = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + true_cfg_scale=args.cfg_scale, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_images_per_prompt, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if profiler_enabled: + 
print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + # Extract images from OmniRequestOutput + # omni.generate() returns list[OmniRequestOutput], extract images from the first output + if not outputs or len(outputs) == 0: + raise ValueError("No output generated from omni.generate()") + logger.info(f"Outputs: {outputs}") + + # Extract images from request_output[0]['images'] + first_output = outputs[0] + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + if not images: + raise ValueError("No images found in request_output") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "qwen_image_output" + if len(images) <= 1: + images[0].save(output_path) + print(f"Saved generated image to {output_path}") + else: + for idx, img in enumerate(images): + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved generated image to {save_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/text_to_video/text_to_video.md b/examples/offline_inference/text_to_video/text_to_video.md new file mode 100644 index 0000000000000000000000000000000000000000..04f1a2653bb1c1bb24f6466d34be85aaf2a8f5ea --- /dev/null +++ b/examples/offline_inference/text_to_video/text_to_video.md @@ -0,0 +1,37 @@ +# Text-To-Video + +The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts. + +## Local CLI Usage + +```bash +python text_to_video.py \ + --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --negative_prompt "<optional quality filter>" \ + --height 480 \ + --width 832 \ + --num_frames 33 \ + --guidance_scale 4.0 \ + --guidance_scale_high 3.0 \ + --flow_shift 12.0 \ + --num_inference_steps 40 \ + --fps 16 \ + --output t2v_out.mp4 +``` + +Key arguments: + +- `--prompt`: text description (string). +- `--height/--width`: output resolution (defaults 480x832, i.e. 480P). Dimensions should align with Wan VAE downsampling (multiples of 8). +- `--num_frames`: Number of frames (Wan default is 81). +- `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high). +- `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string). +- `--boundary_ratio`: Boundary split ratio for low/high DiT. Default `0.875` uses both transformers for best quality. Set to `1.0` to load only the low-noise transformer (saves noticeable memory with good quality, recommended if memory is limited). Set to `0.0` loads only the high-noise transformer (not recommended, lower quality). 
+- `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--output`: path to save the generated video. +- `--vae_use_slicing`: enable VAE slicing for memory optimization. +- `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--enable-cpu-offload`: enable CPU offloading for diffusion models. + +> ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..e9dd2d08562b18aa7c8d31e0acfbf76a0ab534b5 --- /dev/null +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import os +import time +from pathlib import Path + +import numpy as np +import torch + +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video with Wan2.2 T2V.") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.2-T2V-A14B-Diffusers", + help="Diffusers Wan2.2 model ID or local path.", + ) + parser.add_argument("--prompt", default="A serene lakeside sunrise with mist over the water.", help="Text prompt.") + parser.add_argument("--negative_prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument("--guidance_scale", type=float, default=4.0, help="CFG scale (applied to low/high).") + parser.add_argument("--guidance_scale_high", type=float, default=None, help="Optional separate CFG for high-noise.") + parser.add_argument("--height", type=int, default=720, help="Video height.") + parser.add_argument("--width", type=int, default=1280, help="Video width.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames (Wan default is 81).") + parser.add_argument("--num_inference_steps", type=int, default=40, help="Sampling steps.") + parser.add_argument( + "--boundary_ratio", + type=float, + default=0.875, + help="Boundary split ratio for low/high DiT. Default 0.875 uses both transformers for best quality. Set to 1.0 to load only the low-noise transformer (saves noticeable memory with good quality, recommended if memory is limited).", + ) + parser.add_argument( + "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." + ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer). " + "Default: None (no cache acceleration)." 
+ ), + ) + parser.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) + parser.add_argument("--output", type=str, default="wan22_output.mp4", help="Path to save the video (mp4).") + parser.add_argument("--fps", type=int, default=24, help="Frames per second for the output video.") + parser.add_argument( + "--vae_use_slicing", + action="store_true", + help="Enable VAE slicing for memory optimization.", + ) + parser.add_argument( + "--vae_use_tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) + parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offloading for diffusion models.", + ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of ready layers (blocks) to keep on GPU during generation.", + ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=1, + help="Number of GPUs used for ulysses sequence parallelism.", + ) + parser.add_argument( + "--ring_degree", + type=int, + default=1, + help="Number of GPUs used for ring sequence parallelism.", + ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + + # Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment). 
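+    # The dict below is only consumed when --cache_backend cache_dit is passed;
+    # otherwise cache_config stays None and no cache acceleration is applied.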
+ cache_config = None + if args.cache_backend == "cache_dit": + cache_config = { + # DBCache parameters [cache-dit only] + "Fn_compute_blocks": 1, # Optimized for single-transformer models + "Bn_compute_blocks": 0, # Number of backward compute blocks + "max_warmup_steps": 4, # Maximum warmup steps (works for few-step models) + "max_cached_steps": 20, + "residual_diff_threshold": 0.24, # Higher threshold for more aggressive caching + "max_continuous_cached_steps": 3, # Limit to prevent precision degradation + # TaylorSeer parameters [cache-dit only] + "enable_taylorseer": False, # Disabled by default (not suitable for few-step models) + "taylorseer_order": 1, # TaylorSeer polynomial order + # SCM (Step Computation Masking) parameters [cache-dit only] + "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" + "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" + } + # Configure parallel settings (only SP is supported for Wan) + # Note: cfg_parallel and tensor_parallel are not implemented for Wan models + parallel_config = DiffusionParallelConfig( + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + ) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + omni = Omni( + model=args.model, + enable_layerwise_offload=args.enable_layerwise_offload, + layerwise_num_gpu_layers=args.layerwise_num_gpu_layers, + vae_use_slicing=args.vae_use_slicing, + vae_use_tiling=args.vae_use_tiling, + boundary_ratio=args.boundary_ratio, + flow_shift=args.flow_shift, + cache_backend=args.cache_backend, + cache_config=cache_config, + enable_cache_dit_summary=args.enable_cache_dit_summary, + enable_cpu_offload=args.enable_cpu_offload, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Frames: {args.num_frames}") + print( + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + ) + print(f" Video size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + + generation_start = time.perf_counter() + frames = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_high, + num_inference_steps=args.num_inference_steps, + num_frames=args.num_frames, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + # Extract video frames from OmniRequestOutput + if isinstance(frames, list) and len(frames) > 0: + first_item = frames[0] + + # Check if it's an OmniRequestOutput + if hasattr(first_item, "final_output_type"): + if first_item.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation." 
+ ) + + # Pipeline mode: extract from nested request_output + if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: + if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: + inner_output = first_item.request_output[0] + if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): + frames = inner_output.images[0] if inner_output.images else None + if frames is None: + raise ValueError("No video frames found in output.") + # Diffusion mode: use direct images field + elif hasattr(first_item, "images") and first_item.images: + frames = first_item.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + try: + from diffusers.utils import export_to_video + except ImportError: + raise ImportError("diffusers is required for export_to_video.") + + # frames may be np.ndarray (preferred) or torch.Tensor + # export_to_video expects a list of frames with values in [0, 1] + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + # [B, C, F, H, W] or [B, F, H, W, C] + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + # If float, assume [-1,1] and normalize to [0,1] + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + else: + video_array = frames + if hasattr(video_array, "shape") and video_array.ndim == 5: + video_array = video_array[0] + + # Convert 4D array (frames, H, W, C) to list of frames for export_to_video + if isinstance(video_array, np.ndarray) and video_array.ndim == 4: + video_array = list(video_array) + + export_to_video(video_array, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3fbea0550b57be673e71aad1b987efd746187573 --- /dev/null +++ b/examples/online_serving/bagel/README.md @@ -0,0 +1,230 @@ +# BAGEL-7B-MoT + +## 🛠️ Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (BAGEL-7B-MoT) + +**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices. 
+ +### Launch the Server + +```bash +# Use default configuration +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 +``` + +Or use the convenience script: + +```bash +cd /workspace/vllm-omni/examples/online_serving/bagel +bash run_server.sh +``` + +If you have a custom stage configs file, launch the server with the command below: + +```bash +vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +Get into the bagel folder: + +```bash +cd examples/online_serving/bagel +``` + +Send request via Python + +```bash +python openai_chat_client.py --prompt "A cute cat" --modality text2img +``` + +The Python client supports the following command-line arguments: + +- `--prompt` (or `-p`): Text prompt for generation (default: `A cute cat`) +- `--output` (or `-o`): Output file path for image results (default: `bagel_output.png`) +- `--server` (or `-s`): Server URL (default: `http://localhost:8091`) +- `--image-url` (or `-i`): Input image URL or local file path (for img2img/img2text modes) +- `--modality` (or `-m`): Task modality (default: `text2img`). Options: `text2img`, `img2img`, `img2text`, `text2text` +- `--height`: Image height in pixels (default: 512) +- `--width`: Image width in pixels (default: 512) +- `--steps`: Number of inference steps (default: 25) +- `--seed`: Random seed (default: 42) +- `--negative`: Negative prompt for image generation + +Example with custom parameters: + +```bash +python openai_chat_client.py \ + --prompt "A futuristic city" \ + --modality text2img \ + --height 768 \ + --width 768 \ + --steps 50 \ + --seed 42 \ + --negative "blurry, low quality" +``` + +## Modality Control + +BAGEL-7B-MoT supports **multiple modality modes** for different use cases. + +The default yaml configuration deploys Thinker and DiT on the same GPU. 
You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml) + +| Modality | Input | Output | Description | +| ----------- | ------------ | ------ | -------------------------------------- | +| `text2img` | Text | Image | Generate images from text prompts | +| `img2img` | Image + Text | Image | Transform images using text guidance | +| `img2text` | Image + Text | Text | Generate text descriptions from images | +| `text2text` | Text | Text | Pure text generation | + +### Text to Image (text2img) + +Generate images from text prompts: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "A beautiful sunset over mountains" \ + --modality text2img \ + --output sunset.png \ + --steps 50 +``` + +**Using curl** + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>A beautiful sunset over mountains<|im_end|>"}]}], + "modalities": ["image"], + "height": 512, + "width": 512, + "num_inference_steps": 50, + "seed": 42 + }' +``` + + +### Image to Image (img2img) + +Transform images based on text prompts: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "Make the cat stand up" \ + --modality img2img \ + --image-url /path/to/input.jpg \ + --output transformed.png +``` + +**Using curl** + +```bash +IMAGE_BASE64=$(base64 -w 0 cat.jpg) + +cat <<EOF > payload.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "<|im_start|>Make the cat stand up<|im_end|>"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,${IMAGE_BASE64}"}} + ] + }], + "modalities": ["image"], + "height": 512, + "width": 512, + "num_inference_steps": 50, + "seed": 42 +} +EOF + +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @payload.json + +``` + +### Image to Text (img2text) + +Generate text descriptions from images: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "Describe this image in detail" \ + --modality img2text \ + --image-url /path/to/image.jpg +``` + +**Using curl** + +```bash +IMAGE_BASE64=$(base64 -w 0 cat.jpg) + +cat <<EOF > payload.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "<|im_start|>user\n<|image_pad|>\nDescribe this image in detail<|im_end|>\n<|im_start|>assistant\n"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,${IMAGE_BASE64}"}} + ] + }], + "modalities": ["text"] +} +EOF + +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @payload.json +``` + +### Text to Text (text2text) + +Pure text generation: + +**Using Python client** + +```bash +python openai_chat_client.py \ + --prompt "What is the capital of France?" \ + --modality text2text +``` + +**Using curl** + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}] + "modalities": ["text"] + }' +``` + +## FAQ + +- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. 
+ +```bash +sudo apt update +sudo apt install ffmpeg +``` + +- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. + +| Stage | VRAM | +| :------------------ | :--------------------------- | +| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** | +| Stage-1 (DiT) | **26.50 GiB** | +| Total | **~42 GiB + KV Cache** | diff --git a/examples/online_serving/bagel/openai_chat_client.py b/examples/online_serving/bagel/openai_chat_client.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f4cac5d7a8835cf8c117ca08b919387f84ddb --- /dev/null +++ b/examples/online_serving/bagel/openai_chat_client.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +Bagel OpenAI-compatible chat client for image generation and multimodal tasks. + +Usage: + python openai_chat_client.py --prompt "A cute cat" --output output.png + python openai_chat_client.py --prompt "Describe this image" --image-url https://example.com/image.png +""" + +import argparse +import base64 +from pathlib import Path + +import requests + + +def generate_image( + prompt: str, + server_url: str = "http://localhost:8091", + image_url: str | None = None, + height: int | None = None, + width: int | None = None, + steps: int | None = None, + seed: int | None = None, + negative_prompt: str | None = None, + modality: str = "text2img", # "text2img" (default), "img2img", "img2text", "text2text" +) -> bytes | str | None: + """Generate an image or text using the chat completions API. + + Args: + prompt: Text description or prompt + server_url: Server URL + image_url: URL or path to input image (for img2img/img2text) + height: Image height in pixels + width: Image width in pixels + steps: Number of inference steps + seed: Random seed + negative_prompt: Negative prompt + modality: Task modality hint + + Returns: + Image bytes (for image outputs) or Text string (for text outputs) or None if failed + """ + + # Construct Message Content + content = [{"type": "text", "text": f"<|im_start|>{prompt}<|im_end|>"}] + + if image_url: + # Check if local file + if Path(image_url).exists(): + with open(image_url, "rb") as f: + b64_data = base64.b64encode(f.read()).decode("utf-8") + final_image_url = f"data:image/jpeg;base64,{b64_data}" + else: + final_image_url = image_url + + content.append({"type": "image_url", "image_url": {"url": final_image_url}}) + + messages = [{"role": "user", "content": content}] + + # Build request payload with all parameters at top level + # Note: vLLM ignores "extra_body", so we put parameters directly in the payload + payload = {"messages": messages} + + # Set output modalities at top level + if modality == "text2img" or modality == "img2img": + payload["modalities"] = ["image"] + elif modality == "img2text" or modality == "text2text": + payload["modalities"] = ["text"] + + # Add generation parameters directly to payload + if height is not None: + payload["height"] = height + if width is not None: + payload["width"] = width + if steps is not None: + payload["num_inference_steps"] = steps + if seed is not None: + payload["seed"] = seed + if negative_prompt: + payload["negative_prompt"] = negative_prompt + + # Send request + try: + print(f"Sending request to {server_url} with modality {modality}...") + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + # Extract content - check ALL choices since 
server may return multiple + # (e.g., text in choices[0], image in choices[1]) + choices = data.get("choices", []) + + # First pass: look for image output in any choice + for choice in choices: + choice_content = choice.get("message", {}).get("content") + + # Handle Image Output + if isinstance(choice_content, list) and len(choice_content) > 0: + first_item = choice_content[0] + if isinstance(first_item, dict) and "image_url" in first_item: + img_url_str = first_item["image_url"].get("url", "") + if img_url_str.startswith("data:image"): + _, b64_data = img_url_str.split(",", 1) + return base64.b64decode(b64_data) + + # Second pass: look for text output if no image found + for choice in choices: + choice_content = choice.get("message", {}).get("content") + if isinstance(choice_content, str) and choice_content: + return choice_content + + print(f"Unexpected response format: {choices}") + return None + + except Exception as e: + print(f"Error: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser(description="Bagel multimodal chat client") + parser.add_argument("--prompt", "-p", default="<|im_start|>A cute cat<|im_end|>", help="Text prompt") + parser.add_argument("--output", "-o", default="bagel_output.png", help="Output file (for image results)") + parser.add_argument("--server", "-s", default="http://localhost:8091", help="Server URL") + + # Modality Control + parser.add_argument("--image-url", "-i", type=str, help="Input image URL or local path") + parser.add_argument( + "--modality", + "-m", + default="text2img", + choices=["text2img", "img2img", "img2text", "text2text"], + help="Task modality", + ) + + # Generation Params + parser.add_argument("--height", type=int, default=512, help="Image height") + parser.add_argument("--width", type=int, default=512, help="Image width") + parser.add_argument("--steps", type=int, default=25, help="Inference steps") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--negative", help="Negative prompt") + + args = parser.parse_args() + + print(f"Mode: {args.modality}") + if args.image_url: + print(f"Input Image: {args.image_url}") + + result = generate_image( + prompt=args.prompt, + server_url=args.server, + image_url=args.image_url, + height=args.height, + width=args.width, + steps=args.steps, + seed=args.seed, + negative_prompt=args.negative, + modality=args.modality, + ) + + if result: + if isinstance(result, bytes): + # It's an image + output_path = Path(args.output) + output_path.write_bytes(result) + print(f"Image saved to: {output_path}") + print(f"Size: {len(result) / 1024:.1f} KB") + elif isinstance(result, str): + # It's text + print("Response:") + print(result) + else: + print("Failed to generate response") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/bagel/run_server.sh b/examples/online_serving/bagel/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..a64057ef033ea3f8a23c94899b53583fee0a2a2e --- /dev/null +++ b/examples/online_serving/bagel/run_server.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Bagel online serving startup script + +MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}" +PORT="${PORT:-8091}" + +echo "Starting Bagel server..." 
+echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" diff --git a/examples/online_serving/image_to_image/README.md b/examples/online_serving/image_to_image/README.md new file mode 100644 index 0000000000000000000000000000000000000000..171a3368043655710427313054684fedc4e92df4 --- /dev/null +++ b/examples/online_serving/image_to_image/README.md @@ -0,0 +1,218 @@ +# Image-To-Image + +This example demonstrates how to deploy Qwen-Image-Edit model for online image editing service using vLLM-Omni. + +For **multi-image** input editing, use **Qwen-Image-Edit-2509** (QwenImageEditPlusPipeline) and send multiple images in the user message content. + +## Start Server + +### Basic Start + +```bash +vllm serve Qwen/Qwen-Image-Edit --omni --port 8092 +``` + +### Multi-Image Edit (Qwen-Image-Edit-2509) + +```bash +vllm serve Qwen/Qwen-Image-Edit-2509 --omni --port 8092 +``` + +### Start with Parameters + + +Or use the startup script: + +```bash +bash run_server.sh +``` + +To serve Qwen-Image-Edit-2509 with the script: + +```bash +MODEL=Qwen/Qwen-Image-Edit-2509 bash run_server.sh +``` + +## API Calls + +### Method 1: Using curl (Image Editing) + +```bash +# Image editing +bash run_curl_image_edit.sh input.png "Convert this image to watercolor style" + +# Or execute directly +IMG_B64=$(base64 -w0 input.png) + +cat <<EOF > request.json +{ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Convert this image to watercolor style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,$IMG_B64"}} + ] + }], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 1, + "seed": 42 + } +} +EOF + +curl -s http://localhost:8092/v1/chat/completions -H "Content-Type: application/json" -d @request.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > output.png +``` + +### Method 2: Using Python Client + +```bash +python openai_chat_client.py --input input.png --prompt "Convert to oil painting style" --output output.png + +# Multi-image editing (Qwen-Image-Edit-2509 server required) +python openai_chat_client.py --input input1.png input2.png --prompt "Combine these images into a single scene" --output output.png +``` + +### Method 3: Using Gradio Demo + +```bash +python gradio_demo.py +# Visit http://localhost:7861 +``` + +## Request Format + +### Image Editing (Using image_url Format) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Convert this image to watercolor style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} + ] + } + ] +} +``` + +### Image Editing (Using Simplified image Format) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"text": "Convert this image to watercolor style"}, + {"image": "BASE64_IMAGE_DATA"} + ] + } + ] +} +``` + +### Image Editing with Parameters + +Use `extra_body` to pass generation parameters: + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Convert to ink wash painting style"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} + ] + } + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "seed": 42 + } +} +``` + +### Multi-Image Editing (Qwen-Image-Edit-2509) + +Provide multiple images in `content` (order matters): + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", 
"text": "Combine these images into a single scene"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."} }, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."} } + ] + } + ] +} +``` + +## Generation Parameters (extra_body) + +| Parameter | Type | Default | Description | +| ------------------------ | ----- | ------- | ------------------------------------- | +| `height` | int | None | Output image height in pixels | +| `width` | int | None | Output image width in pixels | +| `size` | str | None | Output image size (e.g., "1024x1024") | +| `num_inference_steps` | int | 50 | Number of denoising steps | +| `guidance_scale` | float | 7.5 | CFG guidance scale | +| `seed` | int | None | Random seed (reproducible) | +| `negative_prompt` | str | None | Negative prompt | +| `num_outputs_per_prompt` | int | 1 | Number of images to generate | + +## Response Format + +```json +{ + "id": "chatcmpl-xxx", + "created": 1234567890, + "model": "Qwen/Qwen-Image-Edit", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": [{ + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,..." + } + }] + }, + "finish_reason": "stop" + }], + "usage": {...} +} +``` + +## Common Editing Instructions Examples + +| Instruction | Description | +| ---------------------------------------- | ---------------- | +| `Convert this image to watercolor style` | Style transfer | +| `Convert the image to black and white` | Desaturation | +| `Enhance the color saturation` | Color adjustment | +| `Convert to cartoon style` | Cartoonization | +| `Add vintage filter effect` | Filter effect | +| `Convert daytime scene to nighttime` | Scene conversion | + +## File Description + +| File | Description | +| ------------------------ | ---------------------------- | +| `run_server.sh` | Server startup script | +| `run_curl_image_edit.sh` | curl image editing example | +| `openai_chat_client.py` | Python client | +| `gradio_demo.py` | Gradio interactive interface | diff --git a/examples/online_serving/image_to_image/gradio_demo.py b/examples/online_serving/image_to_image/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..8cad48279576be10869d0aa3e95546059abda79e --- /dev/null +++ b/examples/online_serving/image_to_image/gradio_demo.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +""" +Qwen-Image-Edit Gradio Demo for online serving. + +Usage: + python gradio_demo.py [--server http://localhost:8092] [--port 7861] +""" + +import argparse +import base64 +from io import BytesIO + +import gradio as gr +import requests +from PIL import Image + + +def _pil_to_b64_png(img: Image.Image) -> str: + buffer = BytesIO() + img.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def edit_image( + input_image: Image.Image, + extra_images: list[str] | None, + prompt: str, + steps: int, + guidance_scale: float, + seed: int | None, + negative_prompt: str, + server_url: str, +) -> Image.Image | None: + """Edit an image using the chat completions API.""" + if input_image is None: + raise gr.Error("Please upload an image first") + + images: list[Image.Image] = [input_image] + if extra_images: + for p in extra_images: + try: + images.append(Image.open(p).convert("RGB")) + except Exception as e: + raise gr.Error(f"Failed to open image: {p}. 
Error: {e}") from e + + # Build user message with text and image + content: list[dict[str, object]] = [{"type": "text", "text": prompt}] + for img in images: + content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{_pil_to_b64_png(img)}"}}) + + messages = [ + { + "role": "user", + "content": content, + } + ] + + # Build extra_body with generation parameters + extra_body = { + "num_inference_steps": steps, + "guidance_scale": guidance_scale, + } + if seed is not None and seed >= 0: + extra_body["seed"] = seed + if negative_prompt: + extra_body["negative_prompt"] = negative_prompt + + # Build request payload + payload = {"messages": messages, "extra_body": extra_body} + + try: + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + content = data["choices"][0]["message"]["content"] + if isinstance(content, list) and len(content) > 0: + image_url = content[0].get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + _, b64_data = image_url.split(",", 1) + image_bytes = base64.b64decode(b64_data) + return Image.open(BytesIO(image_bytes)) + + return None + + except Exception as e: + print(f"Error: {e}") + raise gr.Error(f"Edit failed: {e}") + + +def create_demo(server_url: str): + """Create Gradio demo interface.""" + + with gr.Blocks(title="Qwen-Image-Edit Demo") as demo: + gr.Markdown("# Qwen-Image-Edit Online Editing") + gr.Markdown( + "Upload an image and describe the editing effect you want. " + "For multi-image editing, upload extra images (requires Qwen-Image-Edit-2509 server)." + ) + + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image( + label="Input Image", + type="pil", + ) + extra_images = gr.File( + label="Additional Images (Optional)", + file_count="multiple", + type="filepath", + ) + prompt = gr.Textbox( + label="Edit Instruction", + placeholder="Describe the editing effect you want...", + lines=2, + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + placeholder="Describe what you don't want...", + lines=2, + ) + + with gr.Row(): + steps = gr.Slider( + label="Inference Steps", + minimum=10, + maximum=100, + value=50, + step=5, + ) + guidance_scale = gr.Slider( + label="Guidance Scale (CFG)", + minimum=1.0, + maximum=20.0, + value=7.5, + step=0.5, + ) + + with gr.Row(): + seed = gr.Number( + label="Random Seed (-1 for random)", + value=-1, + precision=0, + ) + + edit_btn = gr.Button("Edit Image", variant="primary") + + with gr.Column(scale=1): + output_image = gr.Image( + label="Edited Image", + type="pil", + ) + + # Examples + gr.Examples( + examples=[ + ["Convert this image to watercolor style"], + ["Convert the image to black and white"], + ["Enhance the color saturation"], + ["Convert to cartoon style"], + ["Add vintage filter effect"], + ["Convert daytime to nighttime"], + ["Convert to oil painting style"], + ["Add dreamy blur effect"], + ], + inputs=[prompt], + ) + + def process_edit(img, imgs, p, st, g, se, n): + actual_seed = se if se >= 0 else None + return edit_image(img, imgs, p, st, g, actual_seed, n, server_url) + + edit_btn.click( + fn=process_edit, + inputs=[input_image, extra_images, prompt, steps, guidance_scale, seed, negative_prompt], + outputs=[output_image], + ) + + return demo + + +def main(): + parser = argparse.ArgumentParser(description="Qwen-Image-Edit Gradio Demo") + parser.add_argument("--server", 
default="http://localhost:8092", help="Server URL") + parser.add_argument("--port", type=int, default=7861, help="Gradio port") + parser.add_argument("--share", action="store_true", help="Create public link") + + args = parser.parse_args() + + print(f"Connecting to server: {args.server}") + demo = create_demo(args.server) + demo.launch(server_port=args.port, share=args.share) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/image_to_image/openai_chat_client.py b/examples/online_serving/image_to_image/openai_chat_client.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe4b0edecec789e02ad5fa79d989077f2742bda --- /dev/null +++ b/examples/online_serving/image_to_image/openai_chat_client.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Qwen-Image-Edit OpenAI-compatible chat client for image editing. + +Usage: + python openai_chat_client.py --input qwen_image_output.png --prompt "Convert to watercolor style" --output output.png + python openai_chat_client.py --input input.png --prompt "Convert to oil painting" --seed 42 + python openai_chat_client.py --input input1.png input2.png --prompt "Combine these images into a single scene" +""" + +import argparse +import base64 +from io import BytesIO +from pathlib import Path + +import requests +from PIL import Image + + +def _encode_image_as_data_url(input_path: Path) -> str: + image_bytes = input_path.read_bytes() + try: + img = Image.open(BytesIO(image_bytes)) + mime_type = f"image/{img.format.lower()}" if img.format else "image/png" + except Exception: + mime_type = "image/png" + image_b64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_b64}" + + +def edit_image( + input_image: str | Path | list[str | Path], + prompt: str, + server_url: str = "http://localhost:8092", + height: int | None = None, + width: int | None = None, + steps: int | None = None, + guidance_scale: float | None = None, + seed: int | None = None, + negative_prompt: str | None = None, +) -> bytes | None: + """Edit an image using the chat completions API. + + Args: + input_image: Path(s) to input image(s). For multi-image editing, pass multiple paths. 
+ prompt: Text description of the edit + server_url: Server URL + height: Output image height in pixels + width: Output image width in pixels + steps: Number of inference steps + guidance_scale: CFG guidance scale + seed: Random seed + negative_prompt: Negative prompt + + Returns: + Edited image bytes or None if failed + """ + input_images = input_image if isinstance(input_image, list) else [input_image] + input_paths = [Path(p) for p in input_images] + for p in input_paths: + if not p.exists(): + print(f"Error: Input image not found: {p}") + return None + + # Build user message with text and image + content: list[dict[str, object]] = [{"type": "text", "text": prompt}] + for p in input_paths: + content.append({"type": "image_url", "image_url": {"url": _encode_image_as_data_url(p)}}) + + messages = [ + { + "role": "user", + "content": content, + } + ] + + # Build extra_body with generation parameters + extra_body = {} + if steps is not None: + extra_body["num_inference_steps"] = steps + if guidance_scale is not None: + extra_body["guidance_scale"] = guidance_scale + if seed is not None: + extra_body["seed"] = seed + if negative_prompt: + extra_body["negative_prompt"] = negative_prompt + + # Build request payload + payload = {"messages": messages} + if extra_body: + payload["extra_body"] = extra_body + + # Send request + try: + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + # Extract image from response + content = data["choices"][0]["message"]["content"] + if isinstance(content, list) and len(content) > 0: + image_url = content[0].get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + _, b64_data = image_url.split(",", 1) + return base64.b64decode(b64_data) + + print(f"Unexpected response format: {content}") + return None + + except Exception as e: + print(f"Error: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser(description="Qwen-Image-Edit chat client") + parser.add_argument("--input", "-i", required=True, nargs="+", help="Input image path(s)") + parser.add_argument("--prompt", "-p", required=True, help="Edit prompt") + parser.add_argument("--output", "-o", default="output.png", help="Output file") + parser.add_argument("--server", "-s", default="http://localhost:8092", help="Server URL") + parser.add_argument("--height", type=int, default=1024, help="Output image height") + parser.add_argument("--width", type=int, default=1024, help="Output image width") + parser.add_argument("--steps", type=int, default=50, help="Inference steps") + parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--negative", help="Negative prompt") + + args = parser.parse_args() + + if len(args.input) == 1: + print(f"Input: {args.input[0]}") + else: + print(f"Inputs ({len(args.input)}): {', '.join(args.input)}") + print(f"Prompt: {args.prompt}") + + image_bytes = edit_image( + input_image=args.input, + prompt=args.prompt, + server_url=args.server, + height=args.height, + width=args.width, + steps=args.steps, + guidance_scale=args.guidance, + seed=args.seed, + negative_prompt=args.negative, + ) + + if image_bytes: + output_path = Path(args.output) + output_path.write_bytes(image_bytes) + print(f"Image saved to: {output_path}") + print(f"Size: {len(image_bytes) / 1024:.1f} KB") + else: + 
print("Failed to edit image") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/image_to_image/run_curl_image_edit.sh b/examples/online_serving/image_to_image/run_curl_image_edit.sh new file mode 100644 index 0000000000000000000000000000000000000000..748a0ebe545cf44a7f716321fc30a1f75a895d4d --- /dev/null +++ b/examples/online_serving/image_to_image/run_curl_image_edit.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Qwen-Image image-edit (image-to-image) curl example + +set -euo pipefail + +if [[ $# -lt 2 ]]; then + echo "Usage: $0 <input_image> \"<edit_prompt>\" [output_file]" >&2 + exit 1 +fi + +INPUT_IMG=$1 +PROMPT=$2 +SERVER="${SERVER:-http://localhost:8092}" +CURRENT_TIME=$(date +%Y%m%d%H%M%S) +OUTPUT="${3:-image_edit_${CURRENT_TIME}.png}" + +if [[ ! -f "$INPUT_IMG" ]]; then + echo "Input image not found: $INPUT_IMG" >&2 + exit 1 +fi + +IMG_B64=$(base64 -w0 "$INPUT_IMG") + +REQUEST_JSON=$( + jq -n --arg prompt "$PROMPT" --arg img "$IMG_B64" '{ + messages: [{ + role: "user", + content: [ + {"type": "text", "text": $prompt}, + {"type": "image_url", "image_url": {"url": ("data:image/png;base64," + $img)}} + ] + }], + extra_body: { + height: 1024, + width: 1024, + num_inference_steps: 50, + guidance_scale: 1, + seed: 42 + } + }' +) + +echo "Generating edited image..." +echo "Server: $SERVER" +echo "Prompt: $PROMPT" +echo "Input : $INPUT_IMG" +echo "Output: $OUTPUT" + +curl -s "$SERVER/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "$REQUEST_JSON" \ + | jq -r '.choices[0].message.content[0].image_url.url' \ + | cut -d',' -f2 \ + | base64 -d > "$OUTPUT" + +if [[ -f "$OUTPUT" ]]; then + echo "Image saved to: $OUTPUT" + echo "Size: $(du -h "$OUTPUT" | cut -f1)" +else + echo "Failed to generate image" + exit 1 +fi diff --git a/examples/online_serving/image_to_image/run_server.sh b/examples/online_serving/image_to_image/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b8e081d91fd21815e2fcc581b97efdfd1d79af0 --- /dev/null +++ b/examples/online_serving/image_to_image/run_server.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Qwen-Image-Edit online serving startup script + +MODEL="${MODEL:-Qwen/Qwen-Image-Edit}" +PORT="${PORT:-8092}" + +echo "Starting Qwen-Image-Edit server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" diff --git a/examples/online_serving/lora_inference/README.md b/examples/online_serving/lora_inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..16ce55313ddfd300f6b55bb72a4bcbd3e2a7e34d --- /dev/null +++ b/examples/online_serving/lora_inference/README.md @@ -0,0 +1,54 @@ +# Online LoRA Inference (Diffusion) + +This example shows how to use **per-request LoRA** with vLLM-Omni diffusion models via the OpenAI-compatible Chat Completions API. + +> Note: The LoRA adapter path must be readable on the **server** machine (usually a local path or a mounted directory). +> Note: This example uses `/v1/chat/completions`. LoRA payloads for other OpenAI endpoints are not implemented here. 
+ +## Start Server + +```bash +# Pick a diffusion model (examples) +# export MODEL=stabilityai/stable-diffusion-3.5-medium +# export MODEL=Qwen/Qwen-Image + +bash run_server.sh +``` + +## Call API (curl) + +```bash +# Required: local LoRA folder on the server +export LORA_PATH=/path/to/lora_adapter + +# Optional +export SERVER=http://localhost:8091 +export PROMPT="A piece of cheesecake" +export LORA_NAME=my_lora +export LORA_SCALE=1.0 +# Optional: if omitted, the server derives a stable id from LORA_PATH. +# export LORA_INT_ID=123 + +bash run_curl_lora_inference.sh +``` + +## Call API (Python) + +```bash +python openai_chat_client.py \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +``` + +## LoRA Format + +LoRA adapters should be in PEFT format, for example: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` diff --git a/examples/online_serving/lora_inference/openai_chat_client.py b/examples/online_serving/lora_inference/openai_chat_client.py new file mode 100644 index 0000000000000000000000000000000000000000..e24d2fdf65bd59fe134391c17d5dee207f98c3b3 --- /dev/null +++ b/examples/online_serving/lora_inference/openai_chat_client.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +""" +OpenAI-compatible chat client for diffusion LoRA inference. + +Example: + python openai_chat_client.py \ + --server http://localhost:8091 \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +""" + +import argparse +import base64 +from pathlib import Path + +import requests + + +def generate_image( + *, + prompt: str, + server_url: str, + height: int | None, + width: int | None, + num_inference_steps: int | None, + seed: int | None, + lora_name: str | None, + lora_path: str | None, + lora_scale: float | None, + lora_int_id: int | None, +) -> bytes | None: + messages = [{"role": "user", "content": prompt}] + + extra_body: dict = {} + if height is not None: + extra_body["height"] = height + if width is not None: + extra_body["width"] = width + if num_inference_steps is not None: + extra_body["num_inference_steps"] = num_inference_steps + if seed is not None: + extra_body["seed"] = seed + + if lora_path: + lora_body: dict = { + "local_path": lora_path, + "name": lora_name or Path(lora_path).stem, + } + if lora_scale is not None: + lora_body["scale"] = float(lora_scale) + if lora_int_id is not None: + lora_body["int_id"] = int(lora_int_id) + extra_body["lora"] = lora_body + + payload = {"messages": messages} + if extra_body: + payload["extra_body"] = extra_body + + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + content = data["choices"][0]["message"]["content"] + if isinstance(content, list) and content: + image_url = content[0].get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + _, b64_data = image_url.split(",", 1) + return base64.b64decode(b64_data) + + raise RuntimeError(f"Unexpected response format: {content!r}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Diffusion LoRA OpenAI chat client") + parser.add_argument("--server", default="http://localhost:8091", help="Server URL") + parser.add_argument("--prompt", default="A piece of cheesecake", help="Text prompt") + 
parser.add_argument("--output", default="lora_online_output.png", help="Output image path") + + parser.add_argument("--height", type=int, default=1024, help="Image height") + parser.add_argument("--width", type=int, default=1024, help="Image width") + parser.add_argument("--steps", type=int, default=50, help="num_inference_steps") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + parser.add_argument("--lora-path", default=None, help="Server-local LoRA adapter folder (PEFT format)") + parser.add_argument("--lora-name", default=None, help="LoRA name (optional)") + parser.add_argument("--lora-scale", type=float, default=1.0, help="LoRA scale") + parser.add_argument( + "--lora-int-id", + type=int, + default=None, + help="LoRA integer id (cache key). If omitted, the server derives a stable id from lora_path.", + ) + + args = parser.parse_args() + + image_bytes = generate_image( + prompt=args.prompt, + server_url=args.server, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + seed=args.seed, + lora_name=args.lora_name, + lora_path=args.lora_path, + lora_scale=args.lora_scale if args.lora_path else None, + lora_int_id=args.lora_int_id if args.lora_path else None, + ) + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_bytes(image_bytes) + print(f"Saved: {out_path} ({len(image_bytes) / 1024:.1f} KiB)") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lora_inference/run_curl_lora_inference.sh b/examples/online_serving/lora_inference/run_curl_lora_inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..14a074fbf876a630a5f2b6331473a6b781bf366e --- /dev/null +++ b/examples/online_serving/lora_inference/run_curl_lora_inference.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Online diffusion LoRA inference via OpenAI-compatible chat API. + +SERVER="${SERVER:-http://localhost:8091}" +PROMPT="${PROMPT:-A piece of cheesecake}" + +LORA_PATH="${LORA_PATH:-}" +LORA_NAME="${LORA_NAME:-lora}" +LORA_SCALE="${LORA_SCALE:-1.0}" +LORA_INT_ID="${LORA_INT_ID:-}" + +HEIGHT="${HEIGHT:-1024}" +WIDTH="${WIDTH:-1024}" +NUM_INFERENCE_STEPS="${NUM_INFERENCE_STEPS:-50}" +SEED="${SEED:-42}" + +CURRENT_TIME=$(date +%Y%m%d%H%M%S) +OUTPUT="${OUTPUT:-lora_online_output_${CURRENT_TIME}.png}" + +if [ -z "$LORA_PATH" ]; then + echo "ERROR: LORA_PATH is required (must be a server-local path)." + exit 1 +fi + +echo "Generating image with LoRA..." 
+echo "Server: $SERVER" +echo "Prompt: $PROMPT" +echo "LoRA: name=$LORA_NAME id=${LORA_INT_ID:-auto} scale=$LORA_SCALE path=$LORA_PATH" +echo "Output: $OUTPUT" + +LORA_INT_ID_FIELD="" +if [ -n "$LORA_INT_ID" ]; then + LORA_INT_ID_FIELD=", \"int_id\": $LORA_INT_ID" +fi + +curl -s "$SERVER/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"messages\": [ + {\"role\": \"user\", \"content\": \"$PROMPT\"} + ], + \"extra_body\": { + \"height\": $HEIGHT, + \"width\": $WIDTH, + \"num_inference_steps\": $NUM_INFERENCE_STEPS, + \"seed\": $SEED, + \"lora\": { + \"name\": \"$LORA_NAME\", + \"local_path\": \"$LORA_PATH\", + \"scale\": $LORA_SCALE$LORA_INT_ID_FIELD + } + } + }" | jq -r '.choices[0].message.content[0].image_url.url' | sed 's/^data:image[^,]*,\s*//' | base64 -d > "$OUTPUT" + +if [ -f "$OUTPUT" ]; then + echo "Image saved to: $OUTPUT" + echo "Size: $(du -h "$OUTPUT" | cut -f1)" +else + echo "Failed to generate image" + exit 1 +fi diff --git a/examples/online_serving/lora_inference/run_server.sh b/examples/online_serving/lora_inference/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..3233dd77397812e3639df8146445a3f824e72540 --- /dev/null +++ b/examples/online_serving/lora_inference/run_server.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Online diffusion serving with vLLM-Omni (OpenAI-compatible API). + +MODEL="${MODEL:-stabilityai/stable-diffusion-3.5-medium}" +PORT="${PORT:-8091}" + +echo "Starting vLLM-Omni diffusion server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +if [ -z "${VLLM_BIN:-}" ]; then + if command -v vllm-omni >/dev/null 2>&1; then + VLLM_BIN="vllm-omni" + else + VLLM_BIN="vllm" + fi +fi + +"$VLLM_BIN" serve "$MODEL" --omni \ + --port "$PORT" diff --git a/examples/online_serving/qwen2_5_omni/README.md b/examples/online_serving/qwen2_5_omni/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1fac9805a796d2523f688a71945c28c78f41a5b7 --- /dev/null +++ b/examples/online_serving/qwen2_5_omni/README.md @@ -0,0 +1,215 @@ +# Qwen2.5-Omni + +## 🛠️ Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (Qwen2.5-Omni) + +### Launch the Server + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +Get into the example folder +```bash +cd examples/online_serving/qwen2_5_omni +``` + +#### Send request via python + +```bash +python openai_chat_completion_client_for_multimodal_generation.py --query-type mixed_modalities +``` + +The Python client supports the following command-line arguments: + +- `--query-type` (or `-q`): Query type (default: `mixed_modalities`). Options: `mixed_modalities`, `use_audio_in_video`, `multi_audios`, `text` +- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` +- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. 
Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` +- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` +- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` + + +For example, to use mixed modalities with all local files: + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --video-path /path/to/your/video.mp4 \ + --image-path /path/to/your/image.jpg \ + --audio-path /path/to/your/audio.wav \ + --prompt "Analyze all the media content and provide a comprehensive summary." +``` + +#### Send request via curl + +```bash +bash run_curl_multimodal_generation.sh mixed_modalities +``` + +## Modality control +You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. + +### Supported modalities + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| `["audio"]` | Text + Audio | +| `["text", "audio"]` | Text + Audio | +| Not specified | Text + Audio (default) | + +### Using curl + +#### Text only + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-Omni-7B", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["text"] + }' +``` + +#### Text + Audio + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-Omni-7B", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["audio"] + }' +``` + +### Using Python client + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --modalities text +``` + +### Using OpenAI Python SDK + +#### Text only + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen2.5-Omni-7B", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["text"] +) +print(response.choices[0].message.content) +``` + +#### Text + Audio + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen2.5-Omni-7B", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["audio"] +) +# Response contains two choices: one with text, one with audio +print(response.choices[0].message.content) # Text response +print(response.choices[1].message.audio) # Audio response +``` + +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. 
+```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --stream +``` + +## Run Local Web UI Demo + +This Web UI demo allows users to interact with the model through a web browser. + +### Running Gradio Demo + +The Gradio demo connects to a vLLM API server. You have two options: + +#### Option 1: One-step Launch Script (Recommended) + +The convenience script launches both the vLLM server and Gradio demo together: + +```bash +./run_gradio_demo.sh --model Qwen/Qwen2.5-Omni-7B --server-port 8091 --gradio-port 7861 +``` + +This script will: +1. Start the vLLM server in the background +2. Wait for the server to be ready +3. Launch the Gradio demo +4. Handle cleanup when you press Ctrl+C + +The script supports the following arguments: +- `--model`: Model name/path (default: Qwen/Qwen2.5-Omni-7B) +- `--server-port`: Port for vLLM server (default: 8091) +- `--gradio-port`: Port for Gradio demo (default: 7861) +- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--server-host`: Host for vLLM server (default: 0.0.0.0) +- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) +- `--share`: Share Gradio demo publicly (creates a public link) + +#### Option 2: Manual Launch (Two-Step Process) + +**Step 1: Launch the vLLM API server** + +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 +``` + +If you have custom stage configs file: +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +**Step 2: Run the Gradio demo** + +In a separate terminal: + +```bash +python gradio_demo.py --model Qwen/Qwen2.5-Omni-7B --api-base http://localhost:8091/v1 --port 7861 +``` + +Then open `http://localhost:7861/` on your local browser to interact with the web UI. + +The gradio script supports the following arguments: + +- `--model`: Model name/path (should match the server model) +- `--api-base`: Base URL for the vLLM API server (default: http://localhost:8091/v1) +- `--ip`: Host/IP for Gradio server (default: 127.0.0.1) +- `--port`: Port for Gradio server (default: 7861) +- `--share`: Share the Gradio demo publicly (creates a public link) + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. 
+``` +sudo apt update +sudo apt install ffmpeg +``` diff --git a/examples/online_serving/qwen2_5_omni/gradio_demo.py b/examples/online_serving/qwen2_5_omni/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d3a67e8ba58810ebef4df470ba2fb143c7b993 --- /dev/null +++ b/examples/online_serving/qwen2_5_omni/gradio_demo.py @@ -0,0 +1,584 @@ +import argparse +import base64 +import io +import os +import random +from pathlib import Path +from typing import Any + +import gradio as gr +import numpy as np +import soundfile as sf +import torch +from openai import OpenAI +from PIL import Image + +SEED = 42 + +SUPPORTED_MODELS: dict[str, dict[str, Any]] = { + "Qwen/Qwen2.5-Omni-7B": { + "sampling_params": { + "thinker": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 2048, + "seed": SEED, + "detokenize": True, + "repetition_penalty": 1.1, + }, + "talker": { + "temperature": 0.9, + "top_p": 0.8, + "top_k": 40, + "max_tokens": 2048, + "seed": SEED, + "detokenize": True, + "repetition_penalty": 1.05, + "stop_token_ids": [8294], + }, + "code2wav": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 2048, + "seed": SEED, + "detokenize": True, + "repetition_penalty": 1.1, + }, + }, + }, +} +# Ensure deterministic behavior across runs. +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.cuda.manual_seed_all(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +os.environ["PYTHONHASHSEED"] = str(SEED) +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Gradio demo for Qwen2.5-Omni online inference.") + parser.add_argument( + "--model", + default="Qwen/Qwen2.5-Omni-7B", + help="Model name/path (should match the server model).", + ) + parser.add_argument( + "--api-base", + default="http://localhost:8091/v1", + help="Base URL for the vLLM API server.", + ) + parser.add_argument( + "--ip", + default="127.0.0.1", + help="Host/IP for gradio `launch`.", + ) + parser.add_argument("--port", type=int, default=7861, help="Port for gradio `launch`.") + parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.") + return parser.parse_args() + + +def build_sampling_params_dict(seed: int, model_key: str) -> list[dict]: + """Build sampling params as dict for HTTP API mode.""" + model_conf = SUPPORTED_MODELS.get(model_key) + if model_conf is None: + raise ValueError(f"Unsupported model '{model_key}'") + + sampling_templates: dict[str, dict[str, Any]] = model_conf["sampling_params"] + sampling_params: list[dict] = [] + for stage_name, template in sampling_templates.items(): + params = dict(template) + params["seed"] = seed + sampling_params.append(params) + return sampling_params + + +def image_to_base64_data_url(image: Image.Image) -> str: + """Convert PIL Image to base64 data URL.""" + buffered = io.BytesIO() + # Convert to RGB if needed + if image.mode != "RGB": + image = image.convert("RGB") + image.save(buffered, format="JPEG") + img_bytes = buffered.getvalue() + img_b64 = base64.b64encode(img_bytes).decode("utf-8") + return f"data:image/jpeg;base64,{img_b64}" + + +def audio_to_base64_data_url(audio_data: tuple[np.ndarray, int]) -> str: + """Convert audio (numpy array, sample_rate) to base64 data URL.""" + audio_np, sample_rate = audio_data + # Convert to int16 format for WAV + if audio_np.dtype != np.int16: + # Normalize to [-1, 1] range if needed + if 
audio_np.dtype == np.float32 or audio_np.dtype == np.float64: + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_np = (audio_np * 32767).astype(np.int16) + else: + audio_np = audio_np.astype(np.int16) + + # Write to WAV bytes + buffered = io.BytesIO() + sf.write(buffered, audio_np, sample_rate, format="WAV") + wav_bytes = buffered.getvalue() + wav_b64 = base64.b64encode(wav_bytes).decode("utf-8") + return f"data:audio/wav;base64,{wav_b64}" + + +def video_to_base64_data_url(video_file: str) -> str: + """Convert video file to base64 data URL.""" + video_path = Path(video_file) + if not video_path.exists(): + raise FileNotFoundError(f"Video file not found: {video_file}") + + # Detect MIME type from extension + video_path_lower = str(video_path).lower() + if video_path_lower.endswith(".mp4"): + mime_type = "video/mp4" + elif video_path_lower.endswith(".webm"): + mime_type = "video/webm" + elif video_path_lower.endswith(".mov"): + mime_type = "video/quicktime" + elif video_path_lower.endswith(".avi"): + mime_type = "video/x-msvideo" + elif video_path_lower.endswith(".mkv"): + mime_type = "video/x-matroska" + else: + mime_type = "video/mp4" + + with open(video_path, "rb") as f: + video_bytes = f.read() + video_b64 = base64.b64encode(video_bytes).decode("utf-8") + return f"data:{mime_type};base64,{video_b64}" + + +def process_audio_file( + audio_file: Any | None, +) -> tuple[np.ndarray, int] | None: + """Normalize Gradio audio input to (np.ndarray, sample_rate).""" + if audio_file is None: + return None + + sample_rate: int | None = None + audio_np: np.ndarray | None = None + + def _load_from_path(path_str: str) -> tuple[np.ndarray, int] | None: + if not path_str: + return None + path = Path(path_str) + if not path.exists(): + return None + data, sr = sf.read(path) + if data.ndim > 1: + data = data[:, 0] + return data.astype(np.float32), int(sr) + + if isinstance(audio_file, tuple): + if len(audio_file) == 2: + first, second = audio_file + # Case 1: (sample_rate, np.ndarray) + if isinstance(first, (int, float)) and isinstance(second, np.ndarray): + sample_rate = int(first) + audio_np = second + # Case 2: (filepath, (sample_rate, np.ndarray or list)) + elif isinstance(first, str): + if isinstance(second, tuple) and len(second) == 2: + sr_candidate, data_candidate = second + if isinstance(sr_candidate, (int, float)) and isinstance(data_candidate, np.ndarray): + sample_rate = int(sr_candidate) + audio_np = data_candidate + if audio_np is None: + loaded = _load_from_path(first) + if loaded is not None: + audio_np, sample_rate = loaded + # Case 3: (None, (sample_rate, np.ndarray)) + elif first is None and isinstance(second, tuple) and len(second) == 2: + sr_candidate, data_candidate = second + if isinstance(sr_candidate, (int, float)) and isinstance(data_candidate, np.ndarray): + sample_rate = int(sr_candidate) + audio_np = data_candidate + elif len(audio_file) == 1 and isinstance(audio_file[0], str): + loaded = _load_from_path(audio_file[0]) + if loaded is not None: + audio_np, sample_rate = loaded + elif isinstance(audio_file, str): + loaded = _load_from_path(audio_file) + if loaded is not None: + audio_np, sample_rate = loaded + + if audio_np is None or sample_rate is None: + return None + + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + + return audio_np.astype(np.float32), sample_rate + + +def process_image_file(image_file: Image.Image | None) -> Image.Image | None: + """Process image file from Gradio input. + + Returns: + PIL Image in RGB mode or None if no image provided. 
+ """ + if image_file is None: + return None + # Convert to RGB if needed + if image_file.mode != "RGB": + image_file = image_file.convert("RGB") + return image_file + + +def run_inference_api( + client: OpenAI, + model: str, + sampling_params_dict: list[dict], + user_prompt: str, + audio_file: tuple[str, tuple[int, np.ndarray]] | None = None, + image_file: Image.Image | None = None, + video_file: str | None = None, + use_audio_in_video: bool = False, + output_modalities: str | None = None, + stream: bool = False, +): + """Run inference using OpenAI API client with multimodal support.""" + if not user_prompt.strip() and not audio_file and not image_file and not video_file: + yield "Please provide at least a text prompt or multimodal input.", None + return + + try: + # Build message content list + content_list = [] + + # Process audio + audio_data = process_audio_file(audio_file) + if audio_data is not None: + audio_url = audio_to_base64_data_url(audio_data) + content_list.append( + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + } + ) + + # Process image + if image_file is not None: + image_data = process_image_file(image_file) + if image_data is not None: + image_url = image_to_base64_data_url(image_data) + content_list.append( + { + "type": "image_url", + "image_url": {"url": image_url}, + } + ) + + # Process video + mm_processor_kwargs = {} + if video_file is not None: + video_url = video_to_base64_data_url(video_file) + video_content = { + "type": "video_url", + "video_url": {"url": video_url}, + } + if use_audio_in_video: + video_content["video_url"]["num_frames"] = 32 # Default max frames + mm_processor_kwargs["use_audio_in_video"] = True + content_list.append(video_content) + + # Add text prompt + if user_prompt.strip(): + content_list.append( + { + "type": "text", + "text": user_prompt, + } + ) + + # Build messages + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + }, + { + "role": "user", + "content": content_list, + }, + ] + + # Build extra_body + extra_body = { + "sampling_params_list": sampling_params_dict, + } + if mm_processor_kwargs: + extra_body["mm_processor_kwargs"] = mm_processor_kwargs + + # Parse output modalities + if output_modalities and output_modalities.strip(): + output_modalities_list = [m.strip() for m in output_modalities.split(",")] + else: + output_modalities_list = None + + # Call API + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + modalities=output_modalities_list, + extra_body=extra_body, + stream=stream, + ) + + if not stream: + # Non-streaming mode: extract outputs and yield once + text_outputs: list[str] = [] + audio_output = None + + for choice in chat_completion.choices: + if choice.message.content: + text_outputs.append(choice.message.content) + if choice.message.audio: + # Decode base64 audio + audio_data = base64.b64decode(choice.message.audio.data) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + + text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." 
+ yield text_response, audio_output + else: + # Streaming mode: yield incremental updates + text_content = "" + audio_output = None + + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + # Handle audio modality + if getattr(chunk, "modality", None) == "audio" and content: + try: + # Decode base64 audio + audio_data = base64.b64decode(content) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + # Yield current text and audio + yield text_content if text_content else "", audio_output + except Exception: # pylint: disable=broad-except + # If audio processing fails, just yield text + yield text_content if text_content else "", None + + # Handle text modality + elif getattr(chunk, "modality", None) == "text": + if content: + text_content += content + # Yield updated text content (keep existing audio if any) + yield text_content, audio_output + + # Final yield with accumulated text and last audio (if any) + yield text_content if text_content else "No text output.", audio_output + + except Exception as exc: # pylint: disable=broad-except + error_msg = f"Inference failed: {exc}" + yield error_msg, None + + +def build_interface( + client: OpenAI, + model: str, + sampling_params_dict: list[dict], +): + """Build Gradio interface for API server mode.""" + + def run_inference( + user_prompt: str, + audio_file: tuple[str, tuple[int, np.ndarray]] | None, + image_file: Image.Image | None, + video_file: str | None, + use_audio_in_video: bool, + output_modalities: str | None = None, + stream: bool = False, + ): + # Always yield from the API function to maintain consistent generator behavior + yield from run_inference_api( + client, + model, + sampling_params_dict, + user_prompt, + audio_file, + image_file, + video_file, + use_audio_in_video, + output_modalities, + stream, + ) + + css = """ + .media-input-container { + display: flex; + gap: 10px; + } + .media-input-container > div { + flex: 1; + } + .media-input-container .image-input, + .media-input-container .audio-input { + height: 300px; + } + .media-input-container .video-column { + height: 300px; + display: flex; + flex-direction: column; + } + .media-input-container .video-input { + flex: 1; + min-height: 0; + } + #generate-btn button { + width: 100%; + } + """ + + with gr.Blocks(css=css) as demo: + gr.Markdown("# vLLM-Omni Online Serving Demo") + gr.Markdown(f"**Model:** {model} \n\n") + + with gr.Column(): + with gr.Row(): + input_box = gr.Textbox( + label="Text Prompt", + placeholder="For example: Describe what happens in the media inputs.", + lines=4, + scale=1, + ) + with gr.Row(elem_classes="media-input-container"): + image_input = gr.Image( + label="Image Input (optional)", + type="pil", + sources=["upload"], + scale=1, + elem_classes="image-input", + ) + with gr.Column(scale=1, elem_classes="video-column"): + video_input = gr.Video( + label="Video Input (optional)", + sources=["upload"], + elem_classes="video-input", + ) + use_audio_in_video_checkbox = gr.Checkbox( + label="Use audio from video", + value=False, + info="Extract the video's audio track when provided.", + ) + audio_input = gr.Audio( + label="Audio Input (optional)", + type="numpy", + sources=["upload", "microphone"], + scale=1, + elem_classes="audio-input", + ) + + with gr.Row(): + 
output_modalities = gr.Textbox( + label="Output Modalities", + value=None, + placeholder="For example: text, image, video. Use comma to separate multiple modalities.", + lines=1, + scale=2, + ) + stream_checkbox = gr.Checkbox( + label="Stream output", + value=False, + info="Enable streaming to see output as it's generated.", + scale=1, + ) + + with gr.Row(): + generate_btn = gr.Button( + "Generate", + variant="primary", + size="lg", + elem_id="generate-btn", + ) + + with gr.Row(): + text_output = gr.Textbox(label="Text Output", lines=10, scale=2) + audio_output = gr.Audio(label="Audio Output", interactive=False, scale=1) + + generate_btn.click( + fn=run_inference, + inputs=[ + input_box, + audio_input, + image_input, + video_input, + use_audio_in_video_checkbox, + output_modalities, + stream_checkbox, + ], + outputs=[text_output, audio_output], + ) + demo.queue() + return demo + + +def main(): + args = parse_args() + + model_name = "/".join(args.model.split("/")[-2:]) + assert model_name in SUPPORTED_MODELS, ( + f"Unsupported model '{model_name}'. Supported models: {SUPPORTED_MODELS.keys()}" + ) + + # Initialize OpenAI client + print(f"Connecting to API server at: {args.api_base}") + client = OpenAI( + api_key="EMPTY", + base_url=args.api_base, + ) + print("✓ Connected to API server") + + # Build sampling params + sampling_params_dict = build_sampling_params_dict(SEED, model_name) + + demo = build_interface( + client, + args.model, + sampling_params_dict, + ) + try: + demo.launch( + server_name=args.ip, + server_port=args.port, + share=args.share, + ) + except KeyboardInterrupt: + print("\nShutting down...") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..a25e97ebf95bc1a84d1d15ba55d021c560e5c1a6 --- /dev/null +++ b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -0,0 +1,451 @@ +import base64 +import os + +import requests +from openai import OpenAI +from vllm.assets.audio import AudioAsset +from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8091/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +SEED = 42 + + +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode("utf-8") + + return result + + +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a local file to base64 format.""" + with open(file_path, "rb") as f: + content = f.read() + result = base64.b64encode(content).decode("utf-8") + return result + + +def get_video_url_from_path(video_path: str | None) -> str: + """Convert a video path (local file or URL) to a video URL format for the API. + + If video_path is None or empty, returns the default URL. + If video_path is a local file path, encodes it to base64 data URL. + If video_path is a URL, returns it as-is. 
+ """ + if not video_path: + # Default video URL + return "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + + # Check if it's a URL (starts with http:// or https://) + if video_path.startswith(("http://", "https://")): + return video_path + + # Otherwise, treat it as a local file path + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + + # Detect video MIME type from file extension + video_path_lower = video_path.lower() + if video_path_lower.endswith(".mp4"): + mime_type = "video/mp4" + elif video_path_lower.endswith(".webm"): + mime_type = "video/webm" + elif video_path_lower.endswith(".mov"): + mime_type = "video/quicktime" + elif video_path_lower.endswith(".avi"): + mime_type = "video/x-msvideo" + elif video_path_lower.endswith(".mkv"): + mime_type = "video/x-matroska" + else: + # Default to mp4 if extension is unknown + mime_type = "video/mp4" + + video_base64 = encode_base64_content_from_file(video_path) + return f"data:{mime_type};base64,{video_base64}" + + +def get_image_url_from_path(image_path: str | None) -> str: + """Convert an image path (local file or URL) to an image URL format for the API. + + If image_path is None or empty, returns the default URL. + If image_path is a local file path, encodes it to base64 data URL. + If image_path is a URL, returns it as-is. + """ + if not image_path: + # Default image URL + return "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" + + # Check if it's a URL (starts with http:// or https://) + if image_path.startswith(("http://", "https://")): + return image_path + + # Otherwise, treat it as a local file path + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + # Detect image MIME type from file extension + image_path_lower = image_path.lower() + if image_path_lower.endswith((".jpg", ".jpeg")): + mime_type = "image/jpeg" + elif image_path_lower.endswith(".png"): + mime_type = "image/png" + elif image_path_lower.endswith(".gif"): + mime_type = "image/gif" + elif image_path_lower.endswith(".webp"): + mime_type = "image/webp" + else: + # Default to jpeg if extension is unknown + mime_type = "image/jpeg" + + image_base64 = encode_base64_content_from_file(image_path) + return f"data:{mime_type};base64,{image_base64}" + + +def get_audio_url_from_path(audio_path: str | None) -> str: + """Convert an audio path (local file or URL) to an audio URL format for the API. + + If audio_path is None or empty, returns the default URL. + If audio_path is a local file path, encodes it to base64 data URL. + If audio_path is a URL, returns it as-is. 
+ """ + if not audio_path: + # Default audio URL + return AudioAsset("mary_had_lamb").url + + # Check if it's a URL (starts with http:// or https://) + if audio_path.startswith(("http://", "https://")): + return audio_path + + # Otherwise, treat it as a local file path + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + # Detect audio MIME type from file extension + audio_path_lower = audio_path.lower() + if audio_path_lower.endswith((".mp3", ".mpeg")): + mime_type = "audio/mpeg" + elif audio_path_lower.endswith(".wav"): + mime_type = "audio/wav" + elif audio_path_lower.endswith(".ogg"): + mime_type = "audio/ogg" + elif audio_path_lower.endswith(".flac"): + mime_type = "audio/flac" + elif audio_path_lower.endswith(".m4a"): + mime_type = "audio/mp4" + else: + # Default to wav if extension is unknown + mime_type = "audio/wav" + + audio_base64 = encode_base64_content_from_file(audio_path) + return f"data:{mime_type};base64,{audio_base64}" + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def get_text_query(custom_prompt: str | None = None): + question = ( + custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + ) + prompt = { + "role": "user", + "content": [ + { + "type": "text", + "text": f"{question}", + } + ], + } + return prompt + + +def get_mixed_modalities_query( + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, + custom_prompt: str | None = None, +): + question = ( + custom_prompt or "What is recited in the audio? What is the content of this image? Why is this video funny?" + ) + video_url = get_video_url_from_path(video_path) + image_url = get_image_url_from_path(image_path) + audio_url = get_audio_url_from_path(audio_path) + prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + }, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "video_url", + "video_url": {"url": video_url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_use_audio_in_video_query(video_path: str | None = None, custom_prompt: str | None = None): + question = custom_prompt or "Describe the content of the video, then convert what the baby say into text." + video_url = get_video_url_from_path(video_path) + + prompt = { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": video_url, + "num_frames": 16, + }, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_multi_audios_query(audio_path: str | None = None, custom_prompt: str | None = None): + question = custom_prompt or "Are these two audio clips the same?" 
+ audio_url = get_audio_url_from_path(audio_path) + prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + }, + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("winning_call").url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, + "text": get_text_query, +} + + +def run_multimodal_generation(args) -> None: + model_name = "Qwen/Qwen2.5-Omni-7B" + thinker_sampling_params = { + "temperature": 0.0, # Deterministic - no randomness + "top_p": 1.0, # Disable nucleus sampling + "top_k": -1, # Disable top-k sampling + "max_tokens": 2048, + "seed": SEED, # Fixed seed for sampling + "detokenize": True, + "repetition_penalty": 1.1, + } + talker_sampling_params = { + "temperature": 0.9, + "top_p": 0.8, + "top_k": 40, + "max_tokens": 2048, + "seed": SEED, # Fixed seed for sampling + "detokenize": True, + "repetition_penalty": 1.05, + "stop_token_ids": [8294], + } + code2wav_sampling_params = { + "temperature": 0.0, # Deterministic - no randomness + "top_p": 1.0, # Disable nucleus sampling + "top_k": -1, # Disable top-k sampling + "max_tokens": 2048, + "seed": SEED, # Fixed seed for sampling + "detokenize": True, + "repetition_penalty": 1.1, + } + + sampling_params_list = [ + thinker_sampling_params, + talker_sampling_params, + code2wav_sampling_params, + ] + + # Get paths and custom prompt from args + video_path = getattr(args, "video_path", None) + image_path = getattr(args, "image_path", None) + audio_path = getattr(args, "audio_path", None) + custom_prompt = getattr(args, "prompt", None) + + # Get the query function and call it with appropriate parameters + query_func = query_map[args.query_type] + if args.query_type == "mixed_modalities": + prompt = query_func( + video_path=video_path, image_path=image_path, audio_path=audio_path, custom_prompt=custom_prompt + ) + elif args.query_type == "use_audio_in_video": + prompt = query_func(video_path=video_path, custom_prompt=custom_prompt) + elif args.query_type == "multi_audios": + prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt) + elif args.query_type == "text": + prompt = query_func(custom_prompt=custom_prompt) + else: + prompt = query_func() + + extra_body = { + "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. 
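+        # Order matches the pipeline stages: [thinker, talker, code2wav].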
+ } + + if args.query_type == "use_audio_in_video": + extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} + + if args.modalities is not None: + output_modalities = args.modalities.split(",") + else: + output_modalities = None + + chat_completion = client.chat.completions.create( + messages=[ + get_system_prompt(), + prompt, + ], + model=model_name, + modalities=output_modalities, + extra_body=extra_body, + stream=args.stream, + ) + + count = 0 + if not args.stream: + for choice in chat_completion.choices: + if choice.message.audio: + audio_data = base64.b64decode(choice.message.audio.data) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"Audio saved to {audio_file_path}") + count += 1 + elif choice.message.content: + print("Chat completion output from text:", choice.message.content) + else: + printed_content = False + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + if getattr(chunk, "modality", None) == "audio" and content: + audio_data = base64.b64decode(content) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"\nAudio saved to {audio_file_path}") + count += 1 + + elif getattr(chunk, "modality", None) == "text": + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + print(content, end="", flush=True) + + +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file or URL. If not provided and query-type uses video, uses default video URL.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file or URL. If not provided and query-type uses image, uses default image URL.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL.", + ) + parser.add_argument( + "--prompt", + "-p", + type=str, + default=None, + help="Custom text prompt/question to use instead of the default prompt for the selected query type.", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Output modalities to use for the prompts.", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream the response.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + run_multimodal_generation(args) diff --git a/examples/online_serving/qwen2_5_omni/run_curl_multimodal_generation.sh b/examples/online_serving/qwen2_5_omni/run_curl_multimodal_generation.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5a265b629ea479202ef69bb4c051e5fe4362e05 --- /dev/null +++ b/examples/online_serving/qwen2_5_omni/run_curl_multimodal_generation.sh @@ -0,0 +1,192 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Default query type +QUERY_TYPE="${1:-mixed_modalities}" + +# Default modalities argument +MODALITIES="${2:-null}" + +# Validate query type +if [[ ! 
"$QUERY_TYPE" =~ ^(mixed_modalities|use_audio_in_video|multi_audios|text)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [mixed_modalities|use_audio_in_video|multi_audios|text] [modalities]" + echo " mixed_modalities: Audio + Image + Video + Text query" + echo " use_audio_in_video: Video + Text query (with audio extraction from video)" + echo " multi_audios: Two audio clips + Text query" + echo " text: Text query" + echo " modalities: Modalities parameter (default: null)" + exit 1 +fi + +SEED=42 + +thinker_sampling_params='{ + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 2048, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.1 +}' + +talker_sampling_params='{ + "temperature": 0.9, + "top_p": 0.8, + "top_k": 40, + "max_tokens": 2048, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.05, + "stop_token_ids": [8294] +}' + +code2wav_sampling_params='{ + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 2048, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.1 +}' +# Above is optional, it has a default setting in stage_configs of the corresponding model. + +# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +WINNING_CALL_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/winning_call.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content and extra fields based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + mixed_modalities) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "What is recited in the audio? What is the content of this image? Why is this video funny?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_audio_in_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Describe the content of the video, then convert what the baby say into text." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs='{ + "use_audio_in_video": true + }' + ;; + multi_audios) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "audio_url", + "audio_url": { + "url": "'"$WINNING_CALL_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Are these two audio clips the same?" 
+ } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + + +output=$(curl -sS -X POST http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @- <<EOF +{ + "model": "Qwen/Qwen2.5-Omni-7B", + "sampling_params_list": $sampling_params_list, + "mm_processor_kwargs": $mm_processor_kwargs, + "modalities": $MODALITIES, + "messages": [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." + } + ] + }, + { + "role": "user", + "content": $user_content + } + ] +} +EOF + ) + +# Here it only shows the text content of the first choice. Audio content has many binaries, so it's not displayed here. +echo "Output of request: $(echo "$output" | jq '.choices[0].message.content')" diff --git a/examples/online_serving/qwen2_5_omni/run_gradio_demo.sh b/examples/online_serving/qwen2_5_omni/run_gradio_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..296cc7fa52369b785e69ff29003d9bbb969616c2 --- /dev/null +++ b/examples/online_serving/qwen2_5_omni/run_gradio_demo.sh @@ -0,0 +1,212 @@ +#!/bin/bash +# Convenience script to launch both vLLM server and Gradio demo for Qwen2.5-Omni +# +# Usage: +# ./run_gradio_demo.sh [OPTIONS] +# +# Example: +# ./run_gradio_demo.sh --model Qwen/Qwen2.5-Omni-7B --server-port 8091 --gradio-port 7861 + +set -e + +# Default values +MODEL="Qwen/Qwen2.5-Omni-7B" +SERVER_PORT=8091 +GRADIO_PORT=7861 +STAGE_CONFIGS_PATH="" +SERVER_HOST="0.0.0.0" +GRADIO_IP="127.0.0.1" +GRADIO_SHARE=false + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --server-port) + SERVER_PORT="$2" + shift 2 + ;; + --gradio-port) + GRADIO_PORT="$2" + shift 2 + ;; + --stage-configs-path) + STAGE_CONFIGS_PATH="$2" + shift 2 + ;; + --server-host) + SERVER_HOST="$2" + shift 2 + ;; + --gradio-ip) + GRADIO_IP="$2" + shift 2 + ;; + --share) + GRADIO_SHARE=true + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --model MODEL Model name/path (default: Qwen/Qwen2.5-Omni-7B)" + echo " --server-port PORT Port for vLLM server (default: 8091)" + echo " --gradio-port PORT Port for Gradio demo (default: 7861)" + echo " --stage-configs-path PATH Path to custom stage configs YAML file (optional)" + echo " --server-host HOST Host for vLLM server (default: 0.0.0.0)" + echo " --gradio-ip IP IP for Gradio demo (default: 127.0.0.1)" + echo " --share Share Gradio demo publicly" + echo " --help Show this help message" + echo "" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Get the directory where this script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +API_BASE="http://localhost:${SERVER_PORT}/v1" +HEALTH_URL="http://localhost:${SERVER_PORT}/health" + +echo "==========================================" +echo "Starting vLLM-Omni Gradio Demo" +echo "==========================================" +echo "Model: $MODEL" +echo "Server: http://${SERVER_HOST}:${SERVER_PORT}" +echo "Gradio: http://${GRADIO_IP}:${GRADIO_PORT}" +echo "==========================================" + +# Build vLLM server command +SERVER_CMD=("vllm" "serve" "$MODEL" "--omni" 
"--port" "$SERVER_PORT" "--host" "$SERVER_HOST") +if [ -n "$STAGE_CONFIGS_PATH" ]; then + SERVER_CMD+=("--stage-configs-path" "$STAGE_CONFIGS_PATH") +fi + +# Function to cleanup on exit +cleanup() { + echo "" + echo "Shutting down..." + if [ -n "$SERVER_PID" ]; then + echo "Stopping vLLM server (PID: $SERVER_PID)..." + kill "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + fi + if [ -n "$GRADIO_PID" ]; then + echo "Stopping Gradio demo (PID: $GRADIO_PID)..." + kill "$GRADIO_PID" 2>/dev/null || true + wait "$GRADIO_PID" 2>/dev/null || true + fi + echo "Cleanup complete" + exit 0 +} + +# Set up signal handlers +trap cleanup SIGINT SIGTERM + +# Start vLLM server with output shown in real-time and saved to log +echo "" +echo "Starting vLLM server..." +LOG_FILE="/tmp/vllm_server_${SERVER_PORT}.log" +"${SERVER_CMD[@]}" 2>&1 | tee "$LOG_FILE" & +SERVER_PID=$! + +# Start a background process to monitor the log for startup completion +STARTUP_COMPLETE=false +TAIL_PID="" + +# Function to cleanup tail process +cleanup_tail() { + if [ -n "$TAIL_PID" ]; then + kill "$TAIL_PID" 2>/dev/null || true + wait "$TAIL_PID" 2>/dev/null || true + fi +} + +# Wait for server to be ready by checking log output +echo "" +echo "Waiting for vLLM server to be ready (checking for 'Application startup complete' message)..." +echo "" + +# Monitor log file for startup completion message +MAX_WAIT=300 # 5 minutes timeout as fallback +ELAPSED=0 + +# Use a temporary file to track startup completion +STARTUP_FLAG="/tmp/vllm_startup_flag_${SERVER_PORT}.tmp" +rm -f "$STARTUP_FLAG" + +# Start monitoring in background +( + tail -f "$LOG_FILE" 2>/dev/null | grep -m 1 "Application startup complete" > /dev/null && touch "$STARTUP_FLAG" +) & +TAIL_PID=$! + +while [ $ELAPSED -lt $MAX_WAIT ]; do + # Check if startup flag file exists (startup complete) + if [ -f "$STARTUP_FLAG" ]; then + cleanup_tail + echo "" + echo "✓ vLLM server is ready!" + STARTUP_COMPLETE=true + break + fi + + # Check if server process is still running + if ! kill -0 "$SERVER_PID" 2>/dev/null; then + cleanup_tail + echo "" + echo "Error: vLLM server failed to start (process terminated)" + wait "$SERVER_PID" 2>/dev/null || true + exit 1 + fi + + sleep 1 + ELAPSED=$((ELAPSED + 1)) +done + +cleanup_tail +rm -f "$STARTUP_FLAG" + +if [ "$STARTUP_COMPLETE" != "true" ]; then + echo "" + echo "Error: vLLM server did not complete startup within ${MAX_WAIT} seconds" + kill "$SERVER_PID" 2>/dev/null || true + exit 1 +fi + +# Start Gradio demo +echo "" +echo "Starting Gradio demo..." +cd "$SCRIPT_DIR" +GRADIO_CMD=("python" "gradio_demo.py" "--model" "$MODEL" "--api-base" "$API_BASE" "--ip" "$GRADIO_IP" "--port" "$GRADIO_PORT") +if [ "$GRADIO_SHARE" = true ]; then + GRADIO_CMD+=("--share") +fi + +"${GRADIO_CMD[@]}" > /tmp/gradio_demo.log 2>&1 & +GRADIO_PID=$! + +echo "" +echo "==========================================" +echo "Both services are running!" 
+echo "==========================================" +echo "vLLM Server: http://${SERVER_HOST}:${SERVER_PORT}" +echo "Gradio Demo: http://${GRADIO_IP}:${GRADIO_PORT}" +echo "" +echo "Press Ctrl+C to stop both services" +echo "==========================================" +echo "" + +# Wait for either process to exit +wait $SERVER_PID $GRADIO_PID || true + +cleanup diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9db1060b756efb9681af5dd0fab67d1c3703b7d7 --- /dev/null +++ b/examples/online_serving/qwen3_omni/README.md @@ -0,0 +1,221 @@ +# Qwen3-Omni + +## 🛠️ Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (Qwen3-Omni) + +### Launch the Server + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + +If you want to open async chunking for qwen3-omni, launch the server with command below + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +Get into the example folder +```bash +cd examples/online_serving/qwen3_omni +``` + +#### Send request via python + +```bash +python openai_chat_completion_client_for_multimodal_generation.py --query-type use_image +``` + +The Python client supports the following command-line arguments: + +- `--query-type` (or `-q`): Query type (default: `use_video`). Options: `text`, `use_audio`, `use_image`, `use_video` +- `--model` (or `-m`): Model name/path (default: `Qwen/Qwen3-Omni-30B-A3B-Instruct`) +- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type is `use_video`, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` +- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type is `use_image`, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` +- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type is `use_audio`, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` +- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` + + +For example, to use a local video file with custom prompt: + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_video \ + --video-path /path/to/your/video.mp4 \ + --prompt "What are the main activities shown in this video?" 
+``` + +#### Send request via curl + +```bash +bash run_curl_multimodal_generation.sh use_image +``` + + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` + +## Modality control +You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. + +### Supported modalities + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| `["audio"]` | Text + Audio | +| `["text", "audio"]` | Text + Audio | +| Not specified | Text + Audio (default) | + +### Using curl + +#### Text only + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["text"] + }' +``` + +#### Text + Audio + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["audio"] + }' +``` + +### Using Python client + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --modalities text +``` + +### Using OpenAI Python SDK + +#### Text only + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen3-Omni-30B-A3B-Instruct", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["text"] +) +print(response.choices[0].message.content) +``` + +#### Text + Audio + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen3-Omni-30B-A3B-Instruct", + messages=[{"role": "user", "content": "Describe vLLM in brief."}], + modalities=["audio"] +) +# Response contains two choices: one with text, one with audio +print(response.choices[0].message.content) # Text response +print(response.choices[1].message.audio) # Audio response +``` + +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --stream +``` + +## Run Local Web UI Demo + +This Web UI demo allows users to interact with the model through a web browser. + +### Running Gradio Demo + +The Gradio demo connects to a vLLM API server. You have two options: + +#### Option 1: One-step Launch Script (Recommended) + +The convenience script launches both the vLLM server and Gradio demo together: + +```bash +./run_gradio_demo.sh --model Qwen/Qwen3-Omni-30B-A3B-Instruct --server-port 8091 --gradio-port 7861 +``` + +This script will: +1. Start the vLLM server in the background +2. Wait for the server to be ready +3. Launch the Gradio demo +4. 
Handle cleanup when you press Ctrl+C + +The script supports the following arguments: +- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct) +- `--server-port`: Port for vLLM server (default: 8091) +- `--gradio-port`: Port for Gradio demo (default: 7861) +- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--server-host`: Host for vLLM server (default: 0.0.0.0) +- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) +- `--share`: Share Gradio demo publicly (creates a public link) + +#### Option 2: Manual Launch (Two-Step Process) + +**Step 1: Launch the vLLM API server** + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + +If you have custom stage configs file: +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +**Step 2: Run the Gradio demo** + +In a separate terminal: + +```bash +python gradio_demo.py --model Qwen/Qwen3-Omni-30B-A3B-Instruct --api-base http://localhost:8091/v1 --port 7861 +``` + +Then open `http://localhost:7861/` on your local browser to interact with the web UI. + +The gradio script supports the following arguments: + +- `--model`: Model name/path (should match the server model) +- `--api-base`: Base URL for the vLLM API server (default: http://localhost:8091/v1) +- `--ip`: Host/IP for Gradio server (default: 127.0.0.1) +- `--port`: Port for Gradio server (default: 7861) +- `--share`: Share the Gradio demo publicly (creates a public link) diff --git a/examples/online_serving/qwen3_omni/gradio_demo.py b/examples/online_serving/qwen3_omni/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..76c4e311a36c5a860a8fd99ec92ac2e0b47e149c --- /dev/null +++ b/examples/online_serving/qwen3_omni/gradio_demo.py @@ -0,0 +1,583 @@ +import argparse +import base64 +import io +import os +import random +from pathlib import Path +from typing import Any + +import gradio as gr +import numpy as np +import soundfile as sf +import torch +from openai import OpenAI +from PIL import Image + +SEED = 42 + +SUPPORTED_MODELS: dict[str, dict[str, Any]] = { + "Qwen/Qwen3-Omni-30B-A3B-Instruct": { + "sampling_params": { + "thinker": { + "temperature": 0.4, + "top_p": 0.9, + "top_k": 1, + "max_tokens": 16384, + "detokenize": True, + "repetition_penalty": 1.05, + "stop_token_ids": [151645], + "seed": SEED, + }, + "talker": { + "temperature": 0.9, + "top_k": 50, + "max_tokens": 4096, + "seed": SEED, + "detokenize": False, + "repetition_penalty": 1.05, + "stop_token_ids": [2150], + }, + "code2wav": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 4096 * 16, + "seed": SEED, + "detokenize": True, + "repetition_penalty": 1.1, + }, + }, + }, +} +# Ensure deterministic behavior across runs. 
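+# (Note: strict GPU determinism may additionally require
+# torch.use_deterministic_algorithms(True); it is not enabled in this demo.)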
+random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.cuda.manual_seed_all(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +os.environ["PYTHONHASHSEED"] = str(SEED) +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Gradio demo for Qwen3-Omni online inference.") + parser.add_argument( + "--model", + default="Qwen/Qwen3-Omni-30B-A3B-Instruct", + help="Model name/path (should match the server model).", + ) + parser.add_argument( + "--api-base", + default="http://localhost:8091/v1", + help="Base URL for the vLLM API server.", + ) + parser.add_argument( + "--ip", + default="127.0.0.1", + help="Host/IP for gradio `launch`.", + ) + parser.add_argument("--port", type=int, default=7861, help="Port for gradio `launch`.") + parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.") + return parser.parse_args() + + +def build_sampling_params_dict(seed: int, model_key: str) -> list[dict]: + """Build sampling params as dict for HTTP API mode.""" + model_conf = SUPPORTED_MODELS.get(model_key) + if model_conf is None: + raise ValueError(f"Unsupported model '{model_key}'") + + sampling_templates: dict[str, dict[str, Any]] = model_conf["sampling_params"] + sampling_params: list[dict] = [] + for stage_name, template in sampling_templates.items(): + params = dict(template) + params["seed"] = seed + sampling_params.append(params) + return sampling_params + + +def image_to_base64_data_url(image: Image.Image) -> str: + """Convert PIL Image to base64 data URL.""" + buffered = io.BytesIO() + # Convert to RGB if needed + if image.mode != "RGB": + image = image.convert("RGB") + image.save(buffered, format="JPEG") + img_bytes = buffered.getvalue() + img_b64 = base64.b64encode(img_bytes).decode("utf-8") + return f"data:image/jpeg;base64,{img_b64}" + + +def audio_to_base64_data_url(audio_data: tuple[np.ndarray, int]) -> str: + """Convert audio (numpy array, sample_rate) to base64 data URL.""" + audio_np, sample_rate = audio_data + # Convert to int16 format for WAV + if audio_np.dtype != np.int16: + # Normalize to [-1, 1] range if needed + if audio_np.dtype == np.float32 or audio_np.dtype == np.float64: + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_np = (audio_np * 32767).astype(np.int16) + else: + audio_np = audio_np.astype(np.int16) + + # Write to WAV bytes + buffered = io.BytesIO() + sf.write(buffered, audio_np, sample_rate, format="WAV") + wav_bytes = buffered.getvalue() + wav_b64 = base64.b64encode(wav_bytes).decode("utf-8") + return f"data:audio/wav;base64,{wav_b64}" + + +def video_to_base64_data_url(video_file: str) -> str: + """Convert video file to base64 data URL.""" + video_path = Path(video_file) + if not video_path.exists(): + raise FileNotFoundError(f"Video file not found: {video_file}") + + # Detect MIME type from extension + video_path_lower = str(video_path).lower() + if video_path_lower.endswith(".mp4"): + mime_type = "video/mp4" + elif video_path_lower.endswith(".webm"): + mime_type = "video/webm" + elif video_path_lower.endswith(".mov"): + mime_type = "video/quicktime" + elif video_path_lower.endswith(".avi"): + mime_type = "video/x-msvideo" + elif video_path_lower.endswith(".mkv"): + mime_type = "video/x-matroska" + else: + mime_type = "video/mp4" + + with open(video_path, "rb") as f: + video_bytes = f.read() + video_b64 = base64.b64encode(video_bytes).decode("utf-8") + return 
f"data:{mime_type};base64,{video_b64}" + + +def process_audio_file( + audio_file: Any | None, +) -> tuple[np.ndarray, int] | None: + """Normalize Gradio audio input to (np.ndarray, sample_rate).""" + if audio_file is None: + return None + + sample_rate: int | None = None + audio_np: np.ndarray | None = None + + def _load_from_path(path_str: str) -> tuple[np.ndarray, int] | None: + if not path_str: + return None + path = Path(path_str) + if not path.exists(): + return None + data, sr = sf.read(path) + if data.ndim > 1: + data = data[:, 0] + return data.astype(np.float32), int(sr) + + if isinstance(audio_file, tuple): + if len(audio_file) == 2: + first, second = audio_file + # Case 1: (sample_rate, np.ndarray) + if isinstance(first, (int, float)) and isinstance(second, np.ndarray): + sample_rate = int(first) + audio_np = second + # Case 2: (filepath, (sample_rate, np.ndarray or list)) + elif isinstance(first, str): + if isinstance(second, tuple) and len(second) == 2: + sr_candidate, data_candidate = second + if isinstance(sr_candidate, (int, float)) and isinstance(data_candidate, np.ndarray): + sample_rate = int(sr_candidate) + audio_np = data_candidate + if audio_np is None: + loaded = _load_from_path(first) + if loaded is not None: + audio_np, sample_rate = loaded + # Case 3: (None, (sample_rate, np.ndarray)) + elif first is None and isinstance(second, tuple) and len(second) == 2: + sr_candidate, data_candidate = second + if isinstance(sr_candidate, (int, float)) and isinstance(data_candidate, np.ndarray): + sample_rate = int(sr_candidate) + audio_np = data_candidate + elif len(audio_file) == 1 and isinstance(audio_file[0], str): + loaded = _load_from_path(audio_file[0]) + if loaded is not None: + audio_np, sample_rate = loaded + elif isinstance(audio_file, str): + loaded = _load_from_path(audio_file) + if loaded is not None: + audio_np, sample_rate = loaded + + if audio_np is None or sample_rate is None: + return None + + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + + return audio_np.astype(np.float32), sample_rate + + +def process_image_file(image_file: Image.Image | None) -> Image.Image | None: + """Process image file from Gradio input. + + Returns: + PIL Image in RGB mode or None if no image provided. 
+ """ + if image_file is None: + return None + # Convert to RGB if needed + if image_file.mode != "RGB": + image_file = image_file.convert("RGB") + return image_file + + +def run_inference_api( + client: OpenAI, + model: str, + sampling_params_dict: list[dict], + user_prompt: str, + audio_file: tuple[str, tuple[int, np.ndarray]] | None = None, + image_file: Image.Image | None = None, + video_file: str | None = None, + use_audio_in_video: bool = False, + output_modalities: str | None = None, + stream: bool = False, +): + """Run inference using OpenAI API client with multimodal support.""" + if not user_prompt.strip() and not audio_file and not image_file and not video_file: + yield "Please provide at least a text prompt or multimodal input.", None + + try: + # Build message content list + content_list = [] + + # Process audio + audio_data = process_audio_file(audio_file) + if audio_data is not None: + audio_url = audio_to_base64_data_url(audio_data) + content_list.append( + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + } + ) + + # Process image + if image_file is not None: + image_data = process_image_file(image_file) + if image_data is not None: + image_url = image_to_base64_data_url(image_data) + content_list.append( + { + "type": "image_url", + "image_url": {"url": image_url}, + } + ) + + # Process video + mm_processor_kwargs = {} + if video_file is not None: + video_url = video_to_base64_data_url(video_file) + video_content = { + "type": "video_url", + "video_url": {"url": video_url}, + } + if use_audio_in_video: + video_content["video_url"]["num_frames"] = 32 # Default max frames + mm_processor_kwargs["use_audio_in_video"] = True + content_list.append(video_content) + + # Add text prompt + if user_prompt.strip(): + content_list.append( + { + "type": "text", + "text": user_prompt, + } + ) + + # Build messages + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + }, + { + "role": "user", + "content": content_list, + }, + ] + + # Build extra_body + extra_body = { + "sampling_params_list": sampling_params_dict, + } + if mm_processor_kwargs: + extra_body["mm_processor_kwargs"] = mm_processor_kwargs + + # Parse output modalities + if output_modalities and output_modalities.strip(): + output_modalities_list = [m.strip() for m in output_modalities.split(",")] + else: + output_modalities_list = None + + # Call API + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + modalities=output_modalities_list, + extra_body=extra_body, + stream=stream, + ) + + if not stream: + # Non-streaming mode: extract outputs and yield once + text_outputs: list[str] = [] + audio_output = None + + for choice in chat_completion.choices: + if choice.message.content: + text_outputs.append(choice.message.content) + if choice.message.audio: + # Decode base64 audio + audio_data = base64.b64decode(choice.message.audio.data) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + + text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." 
+ yield text_response, audio_output + else: + # Streaming mode: yield incremental updates + text_content = "" + audio_output = None + + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + # Handle audio modality + if getattr(chunk, "modality", None) == "audio" and content: + try: + # Decode base64 audio + audio_data = base64.b64decode(content) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + # Yield current text and audio + yield text_content if text_content else "", audio_output + except Exception: # pylint: disable=broad-except + # If audio processing fails, just yield text + yield text_content if text_content else "", None + + # Handle text modality + elif getattr(chunk, "modality", None) == "text": + if content: + text_content += content + # Yield updated text content (keep existing audio if any) + yield text_content, audio_output + + # Final yield with accumulated text and last audio (if any) + yield text_content if text_content else "No text output.", audio_output + + except Exception as exc: # pylint: disable=broad-except + error_msg = f"Inference failed: {exc}" + yield error_msg, None + + +def build_interface( + client: OpenAI, + model: str, + sampling_params_dict: list[dict], +): + """Build Gradio interface for API server mode.""" + + def run_inference( + user_prompt: str, + audio_file: tuple[str, tuple[int, np.ndarray]] | None, + image_file: Image.Image | None, + video_file: str | None, + use_audio_in_video: bool, + output_modalities: str | None = None, + stream: bool = False, + ): + # Always yield from the API function to maintain consistent generator behavior + yield from run_inference_api( + client, + model, + sampling_params_dict, + user_prompt, + audio_file, + image_file, + video_file, + use_audio_in_video, + output_modalities, + stream, + ) + + css = """ + .media-input-container { + display: flex; + gap: 10px; + } + .media-input-container > div { + flex: 1; + } + .media-input-container .image-input, + .media-input-container .audio-input { + height: 300px; + } + .media-input-container .video-column { + height: 300px; + display: flex; + flex-direction: column; + } + .media-input-container .video-input { + flex: 1; + min-height: 0; + } + #generate-btn button { + width: 100%; + } + """ + + with gr.Blocks(css=css) as demo: + gr.Markdown("# vLLM-Omni Online Serving Demo") + gr.Markdown(f"**Model:** {model} \n\n") + + with gr.Column(): + with gr.Row(): + input_box = gr.Textbox( + label="Text Prompt", + placeholder="For example: Describe what happens in the media inputs.", + lines=4, + scale=1, + ) + with gr.Row(elem_classes="media-input-container"): + image_input = gr.Image( + label="Image Input (optional)", + type="pil", + sources=["upload"], + scale=1, + elem_classes="image-input", + ) + with gr.Column(scale=1, elem_classes="video-column"): + video_input = gr.Video( + label="Video Input (optional)", + sources=["upload"], + elem_classes="video-input", + ) + use_audio_in_video_checkbox = gr.Checkbox( + label="Use audio from video", + value=False, + info="Extract the video's audio track when provided.", + ) + audio_input = gr.Audio( + label="Audio Input (optional)", + type="numpy", + sources=["upload", "microphone"], + scale=1, + elem_classes="audio-input", + ) + + with gr.Row(): + 
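+                # The server's supported output modalities are "text" and "audio"
+                # (see "Modality control" in the README); leave blank for the default.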
output_modalities = gr.Textbox( + label="Output Modalities", + value=None, + placeholder="For example: text, image, video. Use comma to separate multiple modalities.", + lines=1, + scale=2, + ) + stream_checkbox = gr.Checkbox( + label="Stream output", + value=False, + info="Enable streaming to see output as it's generated.", + scale=1, + ) + + with gr.Row(): + generate_btn = gr.Button( + "Generate", + variant="primary", + size="lg", + elem_id="generate-btn", + ) + + with gr.Row(): + text_output = gr.Textbox(label="Text Output", lines=10, scale=2) + audio_output = gr.Audio(label="Audio Output", interactive=False, scale=1) + + generate_btn.click( + fn=run_inference, + inputs=[ + input_box, + audio_input, + image_input, + video_input, + use_audio_in_video_checkbox, + output_modalities, + stream_checkbox, + ], + outputs=[text_output, audio_output], + ) + demo.queue() + return demo + + +def main(): + args = parse_args() + + model_name = "/".join(args.model.split("/")[-2:]) + assert model_name in SUPPORTED_MODELS, ( + f"Unsupported model '{model_name}'. Supported models: {SUPPORTED_MODELS.keys()}" + ) + + # Initialize OpenAI client + print(f"Connecting to API server at: {args.api_base}") + client = OpenAI( + api_key="EMPTY", + base_url=args.api_base, + ) + print("✓ Connected to API server") + + # Build sampling params + sampling_params_dict = build_sampling_params_dict(SEED, model_name) + + demo = build_interface( + client, + args.model, + sampling_params_dict, + ) + try: + demo.launch( + server_name=args.ip, + server_port=args.port, + share=args.share, + ) + except KeyboardInterrupt: + print("\nShutting down...") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..304b27c4ea4ade36c99596d65968d73a30f1a0c4 --- /dev/null +++ b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -0,0 +1,549 @@ +import base64 +import concurrent.futures +import os +from typing import NamedTuple + +import requests +from openai import OpenAI +from vllm.assets.audio import AudioAsset +from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8091/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +SEED = 42 + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode("utf-8") + + return result + + +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a local file to base64 format.""" + with open(file_path, "rb") as f: + content = f.read() + result = base64.b64encode(content).decode("utf-8") + return result + + +def get_video_url_from_path(video_path: str | None) -> str: + """Convert a video path (local file or URL) to a video URL format for the API. + + If video_path is None or empty, returns the default URL. + If video_path is a local file path, encodes it to base64 data URL. 
+ If video_path is a URL, returns it as-is. + """ + if not video_path: + # Default video URL + return "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + + # Check if it's a URL (starts with http:// or https://) + if video_path.startswith(("http://", "https://")): + return video_path + + # Otherwise, treat it as a local file path + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + + # Detect video MIME type from file extension + video_path_lower = video_path.lower() + if video_path_lower.endswith(".mp4"): + mime_type = "video/mp4" + elif video_path_lower.endswith(".webm"): + mime_type = "video/webm" + elif video_path_lower.endswith(".mov"): + mime_type = "video/quicktime" + elif video_path_lower.endswith(".avi"): + mime_type = "video/x-msvideo" + elif video_path_lower.endswith(".mkv"): + mime_type = "video/x-matroska" + else: + # Default to mp4 if extension is unknown + mime_type = "video/mp4" + + video_base64 = encode_base64_content_from_file(video_path) + return f"data:{mime_type};base64,{video_base64}" + + +def get_image_url_from_path(image_path: str | None) -> str: + """Convert an image path (local file or URL) to an image URL format for the API. + + If image_path is None or empty, returns the default URL. + If image_path is a local file path, encodes it to base64 data URL. + If image_path is a URL, returns it as-is. + """ + if not image_path: + # Default image URL + return "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" + + # Check if it's a URL (starts with http:// or https://) + if image_path.startswith(("http://", "https://")): + return image_path + + # Otherwise, treat it as a local file path + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + # Detect image MIME type from file extension + image_path_lower = image_path.lower() + if image_path_lower.endswith((".jpg", ".jpeg")): + mime_type = "image/jpeg" + elif image_path_lower.endswith(".png"): + mime_type = "image/png" + elif image_path_lower.endswith(".gif"): + mime_type = "image/gif" + elif image_path_lower.endswith(".webp"): + mime_type = "image/webp" + else: + # Default to jpeg if extension is unknown + mime_type = "image/jpeg" + + image_base64 = encode_base64_content_from_file(image_path) + return f"data:{mime_type};base64,{image_base64}" + + +def get_audio_url_from_path(audio_path: str | None) -> str: + """Convert an audio path (local file or URL) to an audio URL format for the API. + + If audio_path is None or empty, returns the default URL. + If audio_path is a local file path, encodes it to base64 data URL. + If audio_path is a URL, returns it as-is. 
+ """ + if not audio_path: + # Default audio URL + return AudioAsset("mary_had_lamb").url + + # Check if it's a URL (starts with http:// or https://) + if audio_path.startswith(("http://", "https://")): + return audio_path + + # Otherwise, treat it as a local file path + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + # Detect audio MIME type from file extension + audio_path_lower = audio_path.lower() + if audio_path_lower.endswith((".mp3", ".mpeg")): + mime_type = "audio/mpeg" + elif audio_path_lower.endswith(".wav"): + mime_type = "audio/wav" + elif audio_path_lower.endswith(".ogg"): + mime_type = "audio/ogg" + elif audio_path_lower.endswith(".flac"): + mime_type = "audio/flac" + elif audio_path_lower.endswith(".m4a"): + mime_type = "audio/mp4" + else: + # Default to wav if extension is unknown + mime_type = "audio/wav" + + audio_base64 = encode_base64_content_from_file(audio_path) + return f"data:{mime_type};base64,{audio_base64}" + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def get_text_query(custom_prompt: str | None = None): + question = ( + custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + ) + prompt = { + "role": "user", + "content": [ + { + "type": "text", + "text": f"{question}", + } + ], + } + return prompt + + +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." +) + + +def get_video_query(video_path: str | None = None, custom_prompt: str | None = None): + question = custom_prompt or "Why is this video funny?" + video_url = get_video_url_from_path(video_path) + prompt = { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": {"url": video_url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +def get_image_query(image_path: str | None = None, custom_prompt: str | None = None): + question = custom_prompt or "What is the content of this image?" + image_url = get_image_url_from_path(image_path) + prompt = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +def get_audio_query(audio_path: str | None = None, custom_prompt: str | None = None): + question = custom_prompt or "What is the content of this audio?" + audio_url = get_audio_url_from_path(audio_path) + prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +def get_mixed_modalities_query( + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, + custom_prompt: str | None = None, +): + """ + Online-friendly multimodal user message: + - Uses URLs (or base64 data URLs) for audio / image / video. + - Returns the OpenAI-style message dict directly (not the offline QueryResult). + """ + question = ( + custom_prompt or "What is recited in the audio? What is the content of this image? Why is this video funny?" 
+ ) + + audio_url = get_audio_url_from_path(audio_path) + image_url = get_image_url_from_path(image_path) + video_url = get_video_url_from_path(video_path) + + return { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": question}, + ], + } + + +def get_multi_audios_query(custom_prompt: str | None = None): + """ + Online-friendly two-audio comparison request. + - Encodes both audio clips as URLs (or data URLs). + - Returns the OpenAI-style message dict. + """ + question = custom_prompt or "Are these two audio clips the same?" + # Use default demo clips; you can point to your own via --audio-path if needed. + audio_url_1 = get_audio_url_from_path(AudioAsset("winning_call").url) + audio_url_2 = get_audio_url_from_path(AudioAsset("mary_had_lamb").url) + + return { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url_1}}, + {"type": "audio_url", "audio_url": {"url": audio_url_2}}, + {"type": "text", "text": question}, + ], + } + + +def get_use_audio_in_video_query( + video_path: str | None = None, + audio_path: str | None = None, + custom_prompt: str | None = None, +): + question = custom_prompt or ( + "Describe the content of the video in details, then convert what the baby say into text." + ) + video_url = get_video_url_from_path(video_path) + audio_url = get_audio_url_from_path(audio_path) + return { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": question}, + ], + } + + +query_map = { + "text": get_text_query, + "use_audio": get_audio_query, + "use_image": get_image_query, + "use_video": get_video_query, + "use_mixed_modalities": get_mixed_modalities_query, + "use_multi_audios": get_multi_audios_query, + "use_audio_in_video": get_use_audio_in_video_query, +} + + +def run_multimodal_generation(args) -> None: + model_name = args.model + thinker_sampling_params = { + "temperature": 0.4, # Deterministic + "top_p": 0.9, + "top_k": 1, + "max_tokens": 16384, + "repetition_penalty": 1.05, + "stop_token_ids": [151645], # Qwen EOS token <|im_end|> + "seed": SEED, + } + + # Sampling parameters for Talker stage (codec generation) + # Stop at codec EOS token + talker_sampling_params = { + "temperature": 0.9, + "top_k": 50, + "max_tokens": 4096, + "seed": SEED, + "detokenize": False, + "repetition_penalty": 1.05, + "stop_token_ids": [2150], # TALKER_CODEC_EOS_TOKEN_ID + } + + # # Sampling parameters for Code2Wav stage (audio generation) + code2wav_sampling_params = { + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 4096 * 16, + "seed": SEED, + "detokenize": True, + "repetition_penalty": 1.1, + } + + sampling_params_list = [ + thinker_sampling_params, + talker_sampling_params, + code2wav_sampling_params, + ] + + # Get paths and custom prompt from args + video_path = getattr(args, "video_path", None) + image_path = getattr(args, "image_path", None) + audio_path = getattr(args, "audio_path", None) + custom_prompt = getattr(args, "prompt", None) + + # Get the query function and call it with appropriate parameters + query_func = query_map[args.query_type] + if args.query_type == "use_video": + prompt = query_func(video_path=video_path, custom_prompt=custom_prompt) + elif args.query_type == "use_image": + prompt = query_func(image_path=image_path, 
custom_prompt=custom_prompt) + elif args.query_type == "use_audio": + prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt) + elif args.query_type == "text": + prompt = query_func(custom_prompt=custom_prompt) + elif args.query_type == "use_audio_in_video": + prompt = query_func( + video_path=video_path, + audio_path=audio_path, + custom_prompt=custom_prompt, + ) + else: + prompt = query_func() + + extra_body = { + "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. + } + + if args.query_type == "use_audio_in_video": + extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} + + if args.modalities is not None: + output_modalities = args.modalities.split(",") + else: + output_modalities = None + + # Test multiple concurrent completions + num_concurrent_requests = args.num_concurrent_requests + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor: + # Submit multiple completion requests concurrently + futures = [ + executor.submit( + client.chat.completions.create, + messages=[ + get_system_prompt(), + prompt, + ], + model=model_name, + modalities=output_modalities, + extra_body=extra_body, + stream=args.stream, + ) + for _ in range(num_concurrent_requests) + ] + + # Wait for all requests to complete and collect results + chat_completions = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert len(chat_completions) == num_concurrent_requests + count = 0 + if not args.stream: + # Verify all completions succeeded + for chat_completion in chat_completions: + for choice in chat_completion.choices: + if choice.message.audio: + audio_data = base64.b64decode(choice.message.audio.data) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"Audio saved to {audio_file_path}") + count += 1 + elif choice.message.content: + print("Chat completion output from text:", choice.message.content) + else: + printed_content = False + for chat_completion in chat_completions: + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + if getattr(chunk, "modality", None) == "audio" and content: + audio_data = base64.b64decode(content) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"\nAudio saved to {audio_file_path}") + count += 1 + + elif getattr(chunk, "modality", None) == "text": + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + print(content, end="", flush=True) + + +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="use_mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="Qwen/Qwen3-Omni-30B-A3B-Instruct", + help="Model Name / Path", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file or URL. If not provided and query-type is 'use_video', uses default video URL.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file or URL. 
If not provided and query-type is 'use_image', uses default image URL.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file or URL. If not provided and query-type is 'use_audio', uses default audio URL.", + ) + parser.add_argument( + "--prompt", + "-p", + type=str, + default=None, + help="Custom text prompt/question to use instead of the default prompt for the selected query type.", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Output modalities to use for the prompts.", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream the response.", + ) + parser.add_argument( + "--num-concurrent-requests", + type=int, + default=1, + help="Number of concurrent requests to send. Default is 1.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + run_multimodal_generation(args) diff --git a/examples/online_serving/qwen3_omni/qwen3_omni_moe_thinking.yaml b/examples/online_serving/qwen3_omni/qwen3_omni_moe_thinking.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34b63e26eaed9db48d5814bba418105809aa0a2c --- /dev/null +++ b/examples/online_serving/qwen3_omni/qwen3_omni_moe_thinking.yaml @@ -0,0 +1,36 @@ +# Stage config for running Qwen3-Omni-MoE-Thinking (text-only output) +# This config is for models like Qwen3-Omni-30B-A3B-Thinking that only have the +# thinker component and do not support audio output. +# +# Single stage: Thinker (multimodal understanding + text generation) + +# The following config has been verified on 2x H100-80G GPUs. +stage_args: + - stage_id: 0 + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 2 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 diff --git a/examples/online_serving/qwen3_omni/run_curl_multimodal_generation.sh b/examples/online_serving/qwen3_omni/run_curl_multimodal_generation.sh new file mode 100644 index 0000000000000000000000000000000000000000..e5d1f38c01c1ac3745f3212e0bf55b22bf959401 --- /dev/null +++ b/examples/online_serving/qwen3_omni/run_curl_multimodal_generation.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Default query type +QUERY_TYPE="${1:-use_video}" + +# Default modalities argument +MODALITIES="${2:-null}" + +# Validate query type +if [[ ! 
"$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [text|use_audio|use_image|use_video] [modalities]" + echo " text: Text query" + echo " use_audio: Audio + Text query" + echo " use_image: Image + Text query" + echo " use_video: Video + Text query" + echo " modalities: Modalities parameter (default: null)" + exit 1 +fi + +SEED=42 + +thinker_sampling_params='{ + "temperature": 0.4, + "top_p": 0.9, + "top_k": 1, + "max_tokens": 16384, + "seed": 42, + "repetition_penalty": 1.05, + "stop_token_ids": [151645] +}' + +talker_sampling_params='{ + "temperature": 0.9, + "top_k": 50, + "max_tokens": 4096, + "seed": 42, + "detokenize": false, + "repetition_penalty": 1.05, + "stop_token_ids": [2150] +}' + +code2wav_sampling_params='{ + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 65536, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.1 +}' +# Above is optional, it has a default setting in stage_configs of the corresponding model. + +# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content and extra fields based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_audio) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "What is the content of this audio?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_image) + user_content='[ + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "text", + "text": "What is the content of this image?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Why is this video funny?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + + +output=$(curl -sS -X POST http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @- <<EOF +{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "sampling_params_list": $sampling_params_list, + "mm_processor_kwargs": $mm_processor_kwargs, + "modalities": $MODALITIES, + "messages": [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." 
+ } + ] + }, + { + "role": "user", + "content": $user_content + } + ] +} +EOF + ) + +# Here it only shows the text content of the first choice. Audio content has many binaries, so it's not displayed here. +echo "Output of request: $(echo "$output" | jq '.choices[0].message.content')" diff --git a/examples/online_serving/qwen3_omni/run_gradio_demo.sh b/examples/online_serving/qwen3_omni/run_gradio_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..73ce273d9a917a850da7d5adb3645935419df413 --- /dev/null +++ b/examples/online_serving/qwen3_omni/run_gradio_demo.sh @@ -0,0 +1,212 @@ +#!/bin/bash +# Convenience script to launch both vLLM server and Gradio demo for Qwen3-Omni +# +# Usage: +# ./run_gradio_demo.sh [OPTIONS] +# +# Example: +# ./run_gradio_demo.sh --model Qwen/Qwen3-Omni-30B-A3B-Instruct --server-port 8091 --gradio-port 7861 + +set -e + +# Default values +MODEL="Qwen/Qwen3-Omni-30B-A3B-Instruct" +SERVER_PORT=8091 +GRADIO_PORT=7861 +STAGE_CONFIGS_PATH="" +SERVER_HOST="0.0.0.0" +GRADIO_IP="127.0.0.1" +GRADIO_SHARE=false + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --server-port) + SERVER_PORT="$2" + shift 2 + ;; + --gradio-port) + GRADIO_PORT="$2" + shift 2 + ;; + --stage-configs-path) + STAGE_CONFIGS_PATH="$2" + shift 2 + ;; + --server-host) + SERVER_HOST="$2" + shift 2 + ;; + --gradio-ip) + GRADIO_IP="$2" + shift 2 + ;; + --share) + GRADIO_SHARE=true + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --model MODEL Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)" + echo " --server-port PORT Port for vLLM server (default: 8091)" + echo " --gradio-port PORT Port for Gradio demo (default: 7861)" + echo " --stage-configs-path PATH Path to custom stage configs YAML file (optional)" + echo " --server-host HOST Host for vLLM server (default: 0.0.0.0)" + echo " --gradio-ip IP IP for Gradio demo (default: 127.0.0.1)" + echo " --share Share Gradio demo publicly" + echo " --help Show this help message" + echo "" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Get the directory where this script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +API_BASE="http://localhost:${SERVER_PORT}/v1" +HEALTH_URL="http://localhost:${SERVER_PORT}/health" + +echo "==========================================" +echo "Starting vLLM-Omni Gradio Demo" +echo "==========================================" +echo "Model: $MODEL" +echo "Server: http://${SERVER_HOST}:${SERVER_PORT}" +echo "Gradio: http://${GRADIO_IP}:${GRADIO_PORT}" +echo "==========================================" + +# Build vLLM server command +SERVER_CMD=("vllm" "serve" "$MODEL" "--omni" "--port" "$SERVER_PORT" "--host" "$SERVER_HOST") +if [ -n "$STAGE_CONFIGS_PATH" ]; then + SERVER_CMD+=("--stage-configs-path" "$STAGE_CONFIGS_PATH") +fi + +# Function to cleanup on exit +cleanup() { + echo "" + echo "Shutting down..." + if [ -n "$SERVER_PID" ]; then + echo "Stopping vLLM server (PID: $SERVER_PID)..." + kill "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + fi + if [ -n "$GRADIO_PID" ]; then + echo "Stopping Gradio demo (PID: $GRADIO_PID)..." 
+ kill "$GRADIO_PID" 2>/dev/null || true + wait "$GRADIO_PID" 2>/dev/null || true + fi + echo "Cleanup complete" + exit 0 +} + +# Set up signal handlers +trap cleanup SIGINT SIGTERM + +# Start vLLM server with output shown in real-time and saved to log +echo "" +echo "Starting vLLM server..." +LOG_FILE="/tmp/vllm_server_${SERVER_PORT}.log" +"${SERVER_CMD[@]}" 2>&1 | tee "$LOG_FILE" & +SERVER_PID=$! + +# Start a background process to monitor the log for startup completion +STARTUP_COMPLETE=false +TAIL_PID="" + +# Function to cleanup tail process +cleanup_tail() { + if [ -n "$TAIL_PID" ]; then + kill "$TAIL_PID" 2>/dev/null || true + wait "$TAIL_PID" 2>/dev/null || true + fi +} + +# Wait for server to be ready by checking log output +echo "" +echo "Waiting for vLLM server to be ready (checking for 'Application startup complete' message)..." +echo "" + +# Monitor log file for startup completion message +MAX_WAIT=300 # 5 minutes timeout as fallback +ELAPSED=0 + +# Use a temporary file to track startup completion +STARTUP_FLAG="/tmp/vllm_startup_flag_${SERVER_PORT}.tmp" +rm -f "$STARTUP_FLAG" + +# Start monitoring in background +( + tail -f "$LOG_FILE" 2>/dev/null | grep -m 1 "Application startup complete" > /dev/null && touch "$STARTUP_FLAG" +) & +TAIL_PID=$! + +while [ $ELAPSED -lt $MAX_WAIT ]; do + # Check if startup flag file exists (startup complete) + if [ -f "$STARTUP_FLAG" ]; then + cleanup_tail + echo "" + echo "✓ vLLM server is ready!" + STARTUP_COMPLETE=true + break + fi + + # Check if server process is still running + if ! kill -0 "$SERVER_PID" 2>/dev/null; then + cleanup_tail + echo "" + echo "Error: vLLM server failed to start (process terminated)" + wait "$SERVER_PID" 2>/dev/null || true + exit 1 + fi + + sleep 1 + ELAPSED=$((ELAPSED + 1)) +done + +cleanup_tail +rm -f "$STARTUP_FLAG" + +if [ "$STARTUP_COMPLETE" != "true" ]; then + echo "" + echo "Error: vLLM server did not complete startup within ${MAX_WAIT} seconds" + kill "$SERVER_PID" 2>/dev/null || true + exit 1 +fi + +# Start Gradio demo +echo "" +echo "Starting Gradio demo..." +cd "$SCRIPT_DIR" +GRADIO_CMD=("python" "gradio_demo.py" "--model" "$MODEL" "--api-base" "$API_BASE" "--ip" "$GRADIO_IP" "--port" "$GRADIO_PORT") +if [ "$GRADIO_SHARE" = true ]; then + GRADIO_CMD+=("--share") +fi + +"${GRADIO_CMD[@]}" > /tmp/gradio_demo.log 2>&1 & +GRADIO_PID=$! + +echo "" +echo "==========================================" +echo "Both services are running!" +echo "==========================================" +echo "vLLM Server: http://${SERVER_HOST}:${SERVER_PORT}" +echo "Gradio Demo: http://${GRADIO_IP}:${GRADIO_PORT}" +echo "" +echo "Press Ctrl+C to stop both services" +echo "==========================================" +echo "" + +# Wait for either process to exit +wait $SERVER_PID $GRADIO_PID || true + +cleanup diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..00f7ec2c93800eba2f4a8199cd9afee777dace57 --- /dev/null +++ b/examples/online_serving/qwen3_tts/README.md @@ -0,0 +1,175 @@ +# Qwen3-TTS Online Serving + +This directory contains examples for running Qwen3-TTS models with vLLM-Omni's online serving API. 
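+
+In addition to the curl and `openai_speech_client.py` examples below, the `/v1/audio/speech` endpoint follows the OpenAI Audio Speech API, so the official `openai` Python SDK can be pointed at the server directly. The snippet below is a minimal, untested sketch assuming `openai>=1.0` is installed; the base URL matches the default port used by `run_server.sh`, the API key can be any placeholder, and Qwen3-TTS-specific fields are forwarded through `extra_body`:
+
+```python
+from openai import OpenAI
+
+# Point the SDK at the local vLLM-Omni server; the key is a placeholder.
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+speech = client.audio.speech.create(
+    model="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
+    input="Hello, how are you?",
+    voice="Vivian",
+    response_format="wav",
+    # Non-OpenAI parameters are merged verbatim into the request body.
+    extra_body={"language": "English", "instructions": "Speak warmly"},
+)
+
+with open("output.wav", "wb") as f:
+    f.write(speech.content)
+```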
+ +## Supported Models + +| Model | Task Type | Description | +|-------|-----------|-------------| +| `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | CustomVoice | Predefined speaker voices with optional style control | +| `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | VoiceDesign | Natural language voice style description | +| `Qwen/Qwen3-TTS-12Hz-1.7B-Base` | Base | Voice cloning from reference audio | + +## Quick Start + +### 1. Start the Server + +```bash +# CustomVoice model (default) +./run_server.sh + +# Or specify task type +./run_server.sh CustomVoice +./run_server.sh VoiceDesign +./run_server.sh Base +``` + +### 2. Run the Client + +```bash +# CustomVoice: Use predefined speaker +python openai_speech_client.py \ + --text "你好,我是通义千问" \ + --voice Vivian \ + --language Chinese + +# CustomVoice with style instruction +python openai_speech_client.py \ + --text "今天天气真好" \ + --voice Ryan \ + --instructions "用开心的语气说" + +# VoiceDesign: Describe the voice style +python openai_speech_client.py \ + --model Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \ + --task-type VoiceDesign \ + --text "哥哥,你回来啦" \ + --instructions "体现撒娇稚嫩的萝莉女声,音调偏高" + +# Base: Voice cloning +python openai_speech_client.py \ + --model Qwen/Qwen3-TTS-12Hz-1.7B-Base \ + --task-type Base \ + --text "Hello, this is a cloned voice" \ + --ref-audio /path/to/reference.wav \ + --ref-text "Original transcript of the reference audio" +``` + +### 3. Using curl + +```bash +# Simple TTS request +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Hello, how are you?", + "voice": "Vivian", + "language": "English" + }' --output output.wav + +# With style instruction +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "I am so excited!", + "voice": "Vivian", + "instructions": "Speak with great enthusiasm" + }' --output excited.wav + +# List available voices in CustomVoice models +curl http://localhost:8000/v1/audio/voices +``` + +## API Reference + +### Endpoint + +``` +POST /v1/audio/speech +``` + +This endpoint follows the [OpenAI Audio Speech API](https://platform.openai.com/docs/api-reference/audio/createSpeech) format with additional Qwen3-TTS parameters. + +### Request Body + +```json +{ + "input": "Text to synthesize", + "voice": "Vivian", + "response_format": "wav", + "task_type": "CustomVoice", + "language": "Auto", + "instructions": "Optional style instructions", + "ref_audio": "URL or base64 for voice cloning", + "ref_text": "Reference audio transcript", + "x_vector_only_mode": false, + "max_new_tokens": 2048 +} +``` + +> **Note:** The `model` field is optional when serving a single model, as the server already knows which model is loaded. + +### Response + +Returns audio data in the requested format (default: WAV). 
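+
+Error responses are returned as a JSON body rather than audio; the bundled `openai_speech_client.py` relies on this when it checks whether the payload starts with `{"error"`. A minimal sketch of a client that performs the same check before writing the file (the URL and payload are illustrative):
+
+```python
+import httpx
+
+resp = httpx.post(
+    "http://localhost:8000/v1/audio/speech",
+    json={"input": "Hello, how are you?", "voice": "Vivian"},
+    timeout=300.0,
+)
+
+# Successful responses are raw audio bytes; error responses are JSON text.
+if resp.status_code != 200 or resp.content[:1] == b"{":
+    print("Request failed:", resp.text)
+else:
+    with open("output.wav", "wb") as f:
+        f.write(resp.content)
+```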
+ +## Parameters + +### Standard OpenAI Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `input` | string | required | Text to synthesize | +| `voice` | string | "Vivian" | Speaker/voice name | +| `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | +| `speed` | float | 1.0 | Playback speed (0.25-4.0) | +| `model` | string | optional | Model name (optional when serving single model) | + +### Qwen3-TTS Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `task_type` | string | "CustomVoice" | Task: CustomVoice, VoiceDesign, or Base | +| `language` | string | "Auto" | Language: Auto, Chinese, English, Japanese, Korean | +| `instructions` | string | "" | Voice style/emotion instructions | +| `max_new_tokens` | int | 2048 | Maximum tokens to generate | + +### Voice Clone Parameters (Base task) + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `ref_audio` | string | Yes* | Reference audio (file path, URL, or base64) | +| `ref_text` | string | No | Transcript of reference audio (for ICL mode) | +| `x_vector_only_mode` | bool | false | Use speaker embedding only (no ICL) | + +## Python Usage + +```python +import httpx + +# Simple request +response = httpx.post( + "http://localhost:8000/v1/audio/speech", + json={ + "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + "input": "Hello world", + "voice": "Vivian", + }, + timeout=300.0, +) + +with open("output.wav", "wb") as f: + f.write(response.content) +``` + +## Limitations + +- **No streaming**: Audio is generated completely before being returned. Streaming will be supported after the pipeline is disaggregated (see RFC #938). +- **Single request**: Batch processing is not yet optimized for online serving. + +## Troubleshooting + +1. **Connection refused**: Make sure the server is running on the correct port +2. **Out of memory**: Reduce `--gpu-memory-utilization` in run_server.sh +3. **Unsupported speaker**: Check supported speakers via model documentation +4. **Voice clone fails**: Ensure you're using the Base model variant for voice cloning diff --git a/examples/online_serving/qwen3_tts/openai_speech_client.py b/examples/online_serving/qwen3_tts/openai_speech_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4d70460940c7ed02b75f7eb90b694d27de3101d3 --- /dev/null +++ b/examples/online_serving/qwen3_tts/openai_speech_client.py @@ -0,0 +1,240 @@ +"""OpenAI-compatible client for Qwen3-TTS via /v1/audio/speech endpoint. + +This script demonstrates how to use the OpenAI-compatible speech API +to generate audio from text using Qwen3-TTS models. + +Examples: + # CustomVoice task (predefined speaker) + python openai_speech_client.py --text "Hello, how are you?" --voice Vivian + + # CustomVoice with emotion instruction + python openai_speech_client.py --text "I'm so happy!" 
--voice Vivian \ + --instructions "Speak with excitement" + + # VoiceDesign task (voice from description) + python openai_speech_client.py --text "Hello world" \ + --task-type VoiceDesign \ + --instructions "A warm, friendly female voice" + + # Base task (voice cloning) + python openai_speech_client.py --text "Hello world" \ + --task-type Base \ + --ref-audio "https://example.com/reference.wav" \ + --ref-text "This is the reference transcript" +""" + +import argparse +import base64 +import os + +import httpx + +# Default server configuration +DEFAULT_API_BASE = "http://localhost:8000" +DEFAULT_API_KEY = "EMPTY" + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to base64 data URL.""" + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + # Detect MIME type from extension + audio_path_lower = audio_path.lower() + if audio_path_lower.endswith(".wav"): + mime_type = "audio/wav" + elif audio_path_lower.endswith((".mp3", ".mpeg")): + mime_type = "audio/mpeg" + elif audio_path_lower.endswith(".flac"): + mime_type = "audio/flac" + elif audio_path_lower.endswith(".ogg"): + mime_type = "audio/ogg" + else: + mime_type = "audio/wav" # Default + + with open(audio_path, "rb") as f: + audio_bytes = f.read() + audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + return f"data:{mime_type};base64,{audio_b64}" + + +def run_tts_generation(args) -> None: + """Run TTS generation via OpenAI-compatible /v1/audio/speech API.""" + + # Build request payload + payload = { + "model": args.model, + "input": args.text, + "voice": args.voice, + "response_format": args.response_format, + } + + # Add optional parameters + if args.instructions: + payload["instructions"] = args.instructions + if args.task_type: + payload["task_type"] = args.task_type + if args.language: + payload["language"] = args.language + if args.max_new_tokens: + payload["max_new_tokens"] = args.max_new_tokens + + # Voice clone parameters (Base task) + if args.ref_audio: + if args.ref_audio.startswith(("http://", "https://")): + payload["ref_audio"] = args.ref_audio + else: + payload["ref_audio"] = encode_audio_to_base64(args.ref_audio) + if args.ref_text: + payload["ref_text"] = args.ref_text + if args.x_vector_only: + payload["x_vector_only_mode"] = True + + print(f"Model: {args.model}") + print(f"Task type: {args.task_type or 'CustomVoice'}") + print(f"Text: {args.text}") + print(f"Voice: {args.voice}") + print("Generating audio...") + + # Make the API call + api_url = f"{args.api_base}/v1/audio/speech" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {args.api_key}", + } + + with httpx.Client(timeout=300.0) as client: + response = client.post(api_url, json=payload, headers=headers) + + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + return + + if response.content.decode("utf-8").startswith('{"error"'): + print(f"Error: {response.content.decode('utf-8')}") + return + + # Save audio response + output_path = args.output or "tts_output.wav" + with open(output_path, "wb") as f: + f.write(response.content) + print(f"Audio saved to: {output_path}") + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="OpenAI-compatible client for Qwen3-TTS via /v1/audio/speech", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Server configuration + parser.add_argument( + "--api-base", + type=str, + 
default=DEFAULT_API_BASE, + help=f"API base URL (default: {DEFAULT_API_BASE})", + ) + parser.add_argument( + "--api-key", + type=str, + default=DEFAULT_API_KEY, + help="API key (default: EMPTY)", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + help="Model name/path", + ) + + # Task configuration + parser.add_argument( + "--task-type", + "-t", + type=str, + default=None, + choices=["CustomVoice", "VoiceDesign", "Base"], + help="TTS task type (default: CustomVoice)", + ) + + # Input text + parser.add_argument( + "--text", + type=str, + required=True, + help="Text to synthesize", + ) + + # Voice/speaker + parser.add_argument( + "--voice", + type=str, + default="Vivian", + help="Speaker/voice name (default: Vivian). Options: Vivian, Ryan, etc.", + ) + parser.add_argument( + "--language", + type=str, + default=None, + help="Language: Auto, Chinese, English, etc.", + ) + parser.add_argument( + "--instructions", + type=str, + default=None, + help="Voice style/emotion instructions", + ) + + # Base (voice clone) parameters + parser.add_argument( + "--ref-audio", + type=str, + default=None, + help="Reference audio file path or URL for voice cloning (Base task)", + ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help="Reference audio transcript for voice cloning (Base task)", + ) + parser.add_argument( + "--x-vector-only", + action="store_true", + help="Use x-vector only mode for voice cloning (no ICL)", + ) + + # Generation parameters + parser.add_argument( + "--max-new-tokens", + type=int, + default=None, + help="Maximum new tokens to generate", + ) + + # Output + parser.add_argument( + "--response-format", + type=str, + default="wav", + choices=["wav", "mp3", "flac", "pcm", "aac", "opus"], + help="Audio output format (default: wav)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output audio file path (default: tts_output.wav)", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + run_tts_generation(args) diff --git a/examples/online_serving/qwen3_tts/run_server.sh b/examples/online_serving/qwen3_tts/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..203ed76e8d394b06771a30ab79703dd24431db99 --- /dev/null +++ b/examples/online_serving/qwen3_tts/run_server.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Launch vLLM-Omni server for Qwen3-TTS models +# +# Usage: +# ./run_server.sh # Default: CustomVoice model +# ./run_server.sh CustomVoice # CustomVoice model +# ./run_server.sh VoiceDesign # VoiceDesign model +# ./run_server.sh Base # Base (voice clone) model + +set -e + +TASK_TYPE="${1:-CustomVoice}" + +case "$TASK_TYPE" in + CustomVoice) + MODEL="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" + ;; + VoiceDesign) + MODEL="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" + ;; + Base) + MODEL="Qwen/Qwen3-TTS-12Hz-1.7B-Base" + ;; + *) + echo "Unknown task type: $TASK_TYPE" + echo "Supported: CustomVoice, VoiceDesign, Base" + exit 1 + ;; +esac + +echo "Starting Qwen3-TTS server with model: $MODEL" + +vllm-omni serve "$MODEL" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni diff --git a/examples/online_serving/text_to_image/README.md b/examples/online_serving/text_to_image/README.md new file mode 100644 index 0000000000000000000000000000000000000000..744b7b2921d6b0722b4d3e275e15c374b4d36b69 --- 
/dev/null +++ b/examples/online_serving/text_to_image/README.md @@ -0,0 +1,159 @@ +# Text-To-Image + +This example demonstrates how to deploy Qwen-Image model for online image generation service using vLLM-Omni. + +## Start Server + +### Basic Start + +```bash +vllm serve Qwen/Qwen-Image --omni --port 8091 +``` +!!! note + If you encounter Out-of-Memory (OOM) issues or have limited GPU memory, you can enable VAE slicing and tiling to reduce memory usage, --vae-use-slicing --vae-use-tiling + +### Start with Parameters + +Or use the startup script: + +```bash +bash run_server.sh +``` + +## API Calls + +### Method 1: Using curl + +```bash +# Basic text-to-image generation +bash run_curl_text_to_image.sh + +# Or execute directly +curl -s http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42 + } + }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png +``` + +### Method 2: Using Python Client + +```bash +python openai_chat_client.py --prompt "A beautiful landscape painting" --output output.png +``` + +### Method 3: Using Gradio Demo + +```bash +python gradio_demo.py +# Visit http://localhost:7860 +``` + +## Request Format + +### Simple Text Generation + +```json +{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ] +} +``` + +### Generation with Parameters + +Use `extra_body` to pass generation parameters: + +```json +{ + "messages": [ + {"role": "user", "content": "A beautiful landscape painting"} + ], + "extra_body": { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42 + } +} +``` + +### Multimodal Input (Text + Structured Content) + +```json +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "A beautiful landscape painting"} + ] + } + ] +} +``` + +## Generation Parameters (extra_body) + +| Parameter | Type | Default | Description | +| ------------------------ | ----- | ------- | ------------------------------ | +| `height` | int | None | Image height in pixels | +| `width` | int | None | Image width in pixels | +| `size` | str | None | Image size (e.g., "1024x1024") | +| `num_inference_steps` | int | 50 | Number of denoising steps | +| `true_cfg_scale` | float | 4.0 | Qwen-Image CFG scale | +| `seed` | int | None | Random seed (reproducible) | +| `negative_prompt` | str | None | Negative prompt | +| `num_outputs_per_prompt` | int | 1 | Number of images to generate | +| `--cfg-parallel-size`. | int | 1 | Number of GPUs for CFG parallelism | + +## Response Format + +```json +{ + "id": "chatcmpl-xxx", + "created": 1234567890, + "model": "Qwen/Qwen-Image", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": [{ + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,..." 
+ } + }] + }, + "finish_reason": "stop" + }], + "usage": {...} +} +``` + +## Extract Image + +```bash +# Extract base64 from response and decode to image +cat response.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png +``` + +## File Description + +| File | Description | +| --------------------------- | ---------------------------- | +| `run_server.sh` | Server startup script | +| `run_curl_text_to_image.sh` | curl example | +| `openai_chat_client.py` | Python client | +| `gradio_demo.py` | Gradio interactive interface | diff --git a/examples/online_serving/text_to_image/gradio_demo.py b/examples/online_serving/text_to_image/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..608db9a23e6dc735ac715e9ef246c38838a31f40 --- /dev/null +++ b/examples/online_serving/text_to_image/gradio_demo.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +""" +Qwen-Image Gradio Demo for online serving. + +Usage: + python gradio_demo.py [--server http://localhost:8091] [--port 7860] +""" + +import argparse +import base64 +from io import BytesIO + +import gradio as gr +import requests +from PIL import Image + + +def generate_image( + prompt: str, + height: int, + width: int, + steps: int, + cfg_scale: float, + seed: int | None, + negative_prompt: str, + server_url: str, + num_outputs_per_prompt: int = 1, +) -> Image.Image | None: + """Generate an image using the chat completions API.""" + messages = [{"role": "user", "content": prompt}] + + # Build extra_body with generation parameters + extra_body = { + "height": height, + "width": width, + "num_inference_steps": steps, + "true_cfg_scale": cfg_scale, + } + if seed is not None and seed >= 0: + extra_body["seed"] = seed + if negative_prompt: + extra_body["negative_prompt"] = negative_prompt + # Keep consistent with run_curl_text_to_image.sh, always send num_outputs_per_prompt + extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt + + # Build request payload + payload = {"messages": messages, "extra_body": extra_body} + + try: + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + content = data["choices"][0]["message"]["content"] + if isinstance(content, list) and len(content) > 0: + image_url = content[0].get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + _, b64_data = image_url.split(",", 1) + image_bytes = base64.b64decode(b64_data) + return Image.open(BytesIO(image_bytes)) + + return None + + except Exception as e: + print(f"Error: {e}") + raise gr.Error(f"Generation failed: {e}") + + +def create_demo(server_url: str): + """Create Gradio demo interface.""" + + with gr.Blocks(title="Qwen-Image Demo") as demo: + gr.Markdown("# Qwen-Image Online Generation") + gr.Markdown("Generate images using Qwen-Image model") + + with gr.Row(): + with gr.Column(scale=1): + prompt = gr.Textbox( + label="Prompt", + placeholder="Describe the image you want to generate...", + lines=3, + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + placeholder="Describe what you don't want...", + lines=2, + ) + + with gr.Row(): + height = gr.Slider( + label="Height", + minimum=256, + maximum=2048, + value=1024, + step=64, + ) + width = gr.Slider( + label="Width", + minimum=256, + maximum=2048, + value=1024, + step=64, + ) + + with gr.Row(): + steps = gr.Slider( + label="Inference Steps", + minimum=10, + 
maximum=100, + # Default steps aligned with run_curl_text_to_image.sh to 100 + value=100, + step=5, + ) + cfg_scale = gr.Slider( + label="True CFG Scale", + minimum=1.0, + maximum=20.0, + value=4.0, + step=0.5, + ) + + with gr.Row(): + seed = gr.Number( + label="Random Seed (-1 for random)", + value=-1, + precision=0, + ) + + generate_btn = gr.Button("Generate Image", variant="primary") + + with gr.Column(scale=1): + output_image = gr.Image( + label="Generated Image", + type="pil", + ) + + # Examples + gr.Examples( + examples=[ + ["A beautiful landscape painting with misty mountains", "", 1024, 1024, 100, 4.0, 42], + ["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 100, 4.0, 123], + ["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 100, 4.0, 456], + ["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 100, 4.0, 789], + ], + inputs=[prompt, negative_prompt, height, width, steps, cfg_scale, seed], + ) + + generate_btn.click( + fn=lambda p, h, w, st, c, se, n: generate_image( + p, + h, + w, + st, + c, + se if se >= 0 else None, + n, + server_url, + 1, + ), + inputs=[prompt, height, width, steps, cfg_scale, seed, negative_prompt], + outputs=[output_image], + ) + + return demo + + +def main(): + parser = argparse.ArgumentParser(description="Qwen-Image Gradio Demo") + parser.add_argument("--server", default="http://localhost:8091", help="Server URL") + parser.add_argument("--port", type=int, default=7860, help="Gradio port") + parser.add_argument("--share", action="store_true", help="Create public link") + + args = parser.parse_args() + + print(f"Connecting to server: {args.server}") + demo = create_demo(args.server) + demo.launch(server_port=args.port, share=args.share) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/text_to_image/openai_chat_client.py b/examples/online_serving/text_to_image/openai_chat_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7beac1accdabc0ed3b92baa61763b00e92e28e4a --- /dev/null +++ b/examples/online_serving/text_to_image/openai_chat_client.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +""" +Qwen-Image OpenAI-compatible image generation client. + +Usage: + python openai_chat_client.py --prompt "A beautiful landscape" --output output.png + python openai_chat_client.py --prompt "A sunset" --height 1024 --width 1024 --steps 50 --seed 42 +""" + +import argparse +import base64 +from pathlib import Path + +import requests + + +def generate_image( + prompt: str, + server_url: str = "http://localhost:8091", + height: int | None = None, + width: int | None = None, + steps: int | None = None, + true_cfg_scale: float | None = None, + seed: int | None = None, + negative_prompt: str | None = None, + num_outputs_per_prompt: int = 1, +) -> bytes | None: + """Generate an image using the images generation API. 
+ + Args: + prompt: Text description of the image + server_url: Server URL + height: Image height in pixels + width: Image width in pixels + steps: Number of diffusion steps + true_cfg_scale: Qwen-Image CFG scale + seed: Random seed + negative_prompt: Negative prompt + num_outputs_per_prompt: Number of images to generate + + Returns: + Image bytes or None if failed + """ + payload: dict[str, object] = { + "prompt": prompt, + "response_format": "b64_json", + "n": num_outputs_per_prompt, + } + + if width is not None and height is not None: + payload["size"] = f"{width}x{height}" + elif width is not None: + payload["size"] = f"{width}x{width}" + elif height is not None: + payload["size"] = f"{height}x{height}" + + if steps is not None: + payload["num_inference_steps"] = steps + if true_cfg_scale is not None: + payload["true_cfg_scale"] = true_cfg_scale + if negative_prompt: + payload["negative_prompt"] = negative_prompt + if seed is not None: + payload["seed"] = seed + + try: + response = requests.post( + f"{server_url}/v1/images/generations", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + items = data.get("data") + if isinstance(items, list) and items: + first = items[0].get("b64_json") if isinstance(items[0], dict) else None + if isinstance(first, str): + return base64.b64decode(first) + + print(f"Unexpected response format: {data}") + return None + + except Exception as e: + print(f"Error: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser(description="Qwen-Image chat client") + parser.add_argument("--prompt", "-p", default="a cup of coffee on the table", help="Text prompt") + parser.add_argument("--output", "-o", default="qwen_image_output.png", help="Output file") + parser.add_argument("--server", "-s", default="http://localhost:8091", help="Server URL") + parser.add_argument("--height", type=int, default=1024, help="Image height") + parser.add_argument("--width", type=int, default=1024, help="Image width") + parser.add_argument("--steps", type=int, default=50, help="Inference steps") + parser.add_argument("--cfg-scale", type=float, default=4.0, help="True CFG scale") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--negative", help="Negative prompt") + + args = parser.parse_args() + + print(f"Generating image for: {args.prompt}") + + image_bytes = generate_image( + prompt=args.prompt, + server_url=args.server, + height=args.height, + width=args.width, + steps=args.steps, + true_cfg_scale=args.cfg_scale, + seed=args.seed, + negative_prompt=args.negative, + ) + + if image_bytes: + output_path = Path(args.output) + output_path.write_bytes(image_bytes) + print(f"Image saved to: {output_path}") + print(f"Size: {len(image_bytes) / 1024:.1f} KB") + else: + print("Failed to generate image") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/text_to_image/run_curl_text_to_image.sh b/examples/online_serving/text_to_image/run_curl_text_to_image.sh new file mode 100644 index 0000000000000000000000000000000000000000..151df89485caa277fe0b01c55c034e946bad17fe --- /dev/null +++ b/examples/online_serving/text_to_image/run_curl_text_to_image.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Qwen-Image text-to-image curl example + +curl -X POST http://localhost:8091/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "a dragon laying over the spine of the Green Mountains of Vermont", + 
"size": "1024x1024", + "seed": 42 + }' | jq -r '.data[0].b64_json' | base64 -d > dragon.png diff --git a/examples/online_serving/text_to_image/run_server.sh b/examples/online_serving/text_to_image/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..b25337a0e930a8663b5b2dbb467cc8b4cb4e7958 --- /dev/null +++ b/examples/online_serving/text_to_image/run_server.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Qwen-Image online serving startup script + +MODEL="${MODEL:-Qwen/Qwen-Image}" +PORT="${PORT:-8091}" + +echo "Starting Qwen-Image server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..15a07450cb2d7b6f49f90cd32001ea1962bb6509 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,168 @@ +site_name: vLLM-Omni +site_description: Efficient omni-modality model serving for everyone +site_author: vLLM-Omni Team +site_url: https://vllm-project.github.io/vllm-omni/ + +repo_name: vllm-project/vllm-omni +repo_url: https://github.com/vllm-project/vllm-omni +edit_uri: edit/main/docs/ + +# Copyright +copyright: Copyright © 2025 vLLM-Omni Team + +# Theme +theme: + name: material + logo: source/logos/vllm-logo-only-light.ico + favicon: source/logos/vllm-logo-only-light.ico + palette: + # Palette toggle for automatic mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + primary: white + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + toggle: + icon: material/brightness-2 + name: Switch to system preference + features: + - content.action.edit + - content.code.copy + - navigation.instant + - navigation.instant.progress + - navigation.tracking + - navigation.tabs + - navigation.tabs.sticky + - navigation.sections + - navigation.indexes + - navigation.top + - search.suggest + - search.highlight + - search.share + - content.code.annotate + - content.tabs + - content.tooltips + - toc.follow + custom_dir: docs/mkdocs/overrides + +hooks: + - docs/mkdocs/hooks/generate_api_readme.py + - docs/mkdocs/hooks/url_schemes.py + - docs/mkdocs/hooks/generate_examples.py + +# Exclude include files from navigation warnings +exclude_docs: | + **/*.inc.md + +# Plugins +plugins: + - meta + - search + - autorefs + - awesome-nav + - glightbox + - git-revision-date-localized: + # exclude files + exclude: + - api/* + - user_guide/examples/** + - contributing/design_documents/api_design_template.md + - DOCS_GUIDE.md + - minify: + minify_html: true + minify_js: true + minify_css: true + cache_safe: true + js_files: [docs/mkdocs/javascript/*.js] + css_files: [docs/mkdocs/stylesheets/*.css] + - api-autonav: + modules: ["vllm_omni"] + api_root_uri: "api" + nav_item_prefix: "" # No prefix in navigation tree (clean names) + show_full_namespace: false # Show only module name, not full path + on_implicit_namespace_package: skip # Skip directories without __init__.py (e.g., assets) + exclude: + - "re:vllm_omni\\._.*" # Internal modules + - "vllm_omni.diffusion.models.qwen_image" # avoid importing vllm in mkdocs building + - "vllm_omni.entrypoints.async_diffusion" # avoid importing vllm in mkdocs building + - "vllm_omni.entrypoints.openai" # avoid importing vllm in mkdocs building + - 
"vllm_omni.entrypoints.openai.protocol" # avoid importing vllm in mkdocs building + - mkdocstrings: + handlers: + python: + options: + show_symbol_type_heading: true + show_symbol_type_toc: true + filters: + - "!^_" # Exclude private members (methods/classes starting with underscore) + summary: + modules: true + show_if_no_docstring: true + show_signature_annotations: true + separate_signature: true + show_overloads: true + signature_crossrefs: true + inventories: + - https://docs.python.org/3/objects.inv + - https://typing-extensions.readthedocs.io/en/latest/objects.inv + - https://docs.aiohttp.org/en/stable/objects.inv + - https://pillow.readthedocs.io/en/stable/objects.inv + - https://numpy.org/doc/stable/objects.inv + # Temporarily disabled due to decompression errors + # - https://pytorch.org/docs/stable/objects.inv + - https://psutil.readthedocs.io/en/stable/objects.inv + +markdown_extensions: + - attr_list + - md_in_html + - admonition + - pymdownx.details + # For content tabs + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + slugify: !!python/object/apply:pymdownx.slugs.slugify + kwds: + case: lower + alternate_style: true + # For code highlighting + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets + # For emoji and icons + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + # For in page [TOC] (not sidebar) + - toc: + permalink: true + # For math rendering + - pymdownx.arithmatex: + generic: true + +extra_css: + - mkdocs/stylesheets/extra.css + +extra_javascript: + - mkdocs/javascript/mathjax.js + - https://unpkg.com/mathjax@3.2.2/es5/tex-mml-chtml.js + - https://unpkg.com/mermaid@10/dist/mermaid.min.js + - mkdocs/javascript/mermaid.js + - mkdocs/javascript/edit_and_feedback.js + - mkdocs/javascript/slack_and_forum.js diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..25454102f0f868747c0ac022b7f0215f618cdc00 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,207 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "vllm-omni" +version = "0.14.0" +description = "A framework for efficient model inference with omni-modality models" +readme = "README.md" +requires-python = ">=3.10,<3.14" +license = {text = "Apache-2.0"} +authors = [ + {name = "vLLM-Omni Team"} +] +keywords = ["vllm", "multimodal", "diffusion", "transformer", "inference", "serving"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] + + +dependencies = [ + # Core runtime dependencies (required for actual usage) + "omegaconf>=2.3.0", + "librosa>=0.11.0", + "resampy>=0.4.3", + "diffusers>=0.36.0", + "accelerate==1.12.0", + "gradio==5.50", + "soundfile>=0.13.1", + "cache-dit==1.2.0", + "tqdm>=4.66.0", + "torchsde>=0.2.6", # Required for Stable Audio 
scheduler + "fa3-fwd==0.0.1", # flash attention 3, maintained by @ZJY0516 + "openai-whisper>=20250625", + "imageio[ffmpeg]>=2.37.2", + "onnxruntime>=1.19.0", + "sox>=1.5.0", + # "vllm==0.14.0", # TODO: fix the entrypoints overwrite problem +] + +[project.optional-dependencies] + +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", + "mypy==1.11.1", + "pre-commit==4.0.1", + "openai-whisper>=20250625", + "psutil>=7.2.0", + "soundfile>=0.13.1", + "imageio[ffmpeg]>=0.6.0", + "opencv-python>=4.12.0.88", + "mooncake-transfer-engine==0.3.8.post1" +] + +docs = [ + "mkdocs>=1.5.0", + "mkdocs-api-autonav", + "mkdocs-material", + "mkdocstrings-python", + "mkdocs-gen-files", + "mkdocs-awesome-nav", + "mkdocs-glightbox", + "mkdocs-git-revision-date-localized-plugin", + "mkdocs-minify-plugin", + "regex", + "ruff", + "pydantic", +] + + +[project.urls] +Homepage = "https://github.com/vllm-project/vllm-omni" +Repository = "https://github.com/vllm-project/vllm-omni" +Documentation = "https://vllm-omni.readthedocs.io" +"Bug Tracker" = "https://github.com/vllm-project/vllm-omni/issues" + +[project.scripts] +vllm = "vllm_omni.entrypoints.cli.main:main" +vllm-omni = "vllm_omni.entrypoints.cli.main:main" + + +[tool.setuptools.packages.find] +where = ["."] +include = ["vllm_omni*"] + +[tool.setuptools.package-data] +"vllm_omni.model_executor.stage_configs" = ["*.yaml"] + +[tool.ruff] +line-length = 120 +exclude = [ + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".tox", + ".venv", + "build", + "dist", + "vllm_omni.egg-info", +] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort (handled separately, but included for compatibility) + "N", # pep8-naming + "UP", # pyupgrade +] +ignore = [ + "E203", # whitespace before ':' (conflicts with black) + # W503 is not needed in ruff as it's compatible with black by default + "N801", # class names should use CapWords convention + "N802", # function name should follow snake_case + "N806", # variable in function should follow snake_case + "N812", # lowercase imported as non-lowercase: functional as F +] + +[tool.ruff.lint.per-file-ignores] +"examples/**" = ["E501"] # Allow long lines in examples +"tests/**" = ["E501"] # Allow long lines in tests + +[tool.mypy] +python_version = "3.12, 3.13" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "--cov=vllm_omni", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-report=xml", +] +markers = [ + # ci/cd required + "core_model: Core model tests (run in each PR)", + # function module markers + "diffusion: Diffusion model tests", + "omni: Omni model tests", + "cache: Cache backend tests", + "parallel: Parallelism/distributed tests", + # platform markers + "cpu: Tests that run on CPU", + "gpu: Tests that run on GPU (auto-added)", + "cuda: Tests that run on CUDA (auto-added)", + "rocm: Tests that run on AMD/ROCm (auto-added)", + "npu: Tests that run on NPU/Ascend (auto-added)", + # specified computation resources marks (auto-added) + "H100: 
Tests that require H100 GPU", + "L4: Tests that require L4 GPU", + "MI325: Tests that require MI325 GPU (AMD/ROCm)", + "A2: Tests that require A2 NPU", + "A3: Tests that require A3 NPU", + "distributed_cuda: Tests that require multi cards on CUDA platform", + "distributed_rocm: Tests that require multi cards on ROCm platform", + "distributed_npu: Tests that require multi cards on NPU platform", + "skipif_cuda: Skip if the num of CUDA cards is less than the required", + "skipif_rocm: Skip if the num of ROCm cards is less than the required", + "skipif_npu: Skip if the num of NPU cards is less than the required", + # more detailed markers + "slow: Slow tests (may skip in quick CI)", + "benchmark: Benchmark tests", +] + +[tool.typos.default] +extend-ignore-identifiers-re = [ + ".*_thw", + ".*thw", + "ein", + ".*arange", + ".*MoBA", + ".*temperal_downsample", + ".*nothink.*", + ".*NOTHINK.*", + ".*nin.*", + "Ono_Anna", +] diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh new file mode 100644 index 0000000000000000000000000000000000000000..54769a3079b853deb41506fdfc56805c1033bf3b --- /dev/null +++ b/scripts/build_wheel.sh @@ -0,0 +1,139 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" +PROJECT_NAME="vllm-omni" +RUN_QUALITY=false +SKIP_CLEAN=false +CREATE_VENV=false +VENV_DIR=".venv-build" +PYTHON_BIN="python" +UV_BIN="uv" + +log() { + local level="$1" + shift + printf '[%s] %s\n' "${level}" "$*" +} + +abort() { + log "ERROR" "$*" + exit 1 +} + +usage() { + cat <<EOF +Usage: $(basename "$0") [options] + +Options: + --run-quality Run pre-commit, install dev deps, and pytest before building + --skip-clean Skip removing previous build artifacts + --create-venv Build inside a fresh virtual environment (default path: .venv-build) + --venv-dir PATH Custom directory for the virtual environment (implies --create-venv) + --python PATH Python executable to use (default: python) + -h, --help Show this help message +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --run-quality) + RUN_QUALITY=true + ;; + --skip-clean) + SKIP_CLEAN=true + ;; + --create-venv) + CREATE_VENV=true + ;; + --venv-dir) + CREATE_VENV=true + shift + [[ $# -gt 0 ]] || abort "--venv-dir requires a path" + VENV_DIR="$1" + ;; + --python) + shift + [[ $# -gt 0 ]] || abort "--python requires a path" + PYTHON_BIN="$1" + ;; + -h|--help) + usage + exit 0 + ;; + *) + usage + abort "Unknown option: $1" + ;; + esac + shift +done + +HOST_PYTHON="${PYTHON_BIN}" + +log "INFO" "Switching to repository root: ${REPO_ROOT}" +cd "${REPO_ROOT}" || abort "Cannot enter repository root" + +[[ -f pyproject.toml ]] || abort "pyproject.toml not found, please ensure correct script location" + +ensure_uv() { + if ! 
command -v "${UV_BIN}" >/dev/null 2>&1; then + log "INFO" "uv not found, installing via ${HOST_PYTHON}" + "${HOST_PYTHON}" -m pip install --upgrade pip + "${HOST_PYTHON}" -m pip install uv + fi +} + +ensure_uv + +if [[ "${CREATE_VENV}" == "true" ]]; then + log "INFO" "Creating fresh virtual environment at ${VENV_DIR} via uv" + "${UV_BIN}" venv --python "${HOST_PYTHON}" --seed "${VENV_DIR}" + PYTHON_BIN="${VENV_DIR}/bin/python" + [[ -x "${PYTHON_BIN}" ]] || abort "Failed to locate python inside ${VENV_DIR}" + log "INFO" "Installing build module inside virtual environment" + "${UV_BIN}" pip install --python "${PYTHON_BIN}" build +else + log "INFO" "Ensuring build module is available via uv pip" + "${UV_BIN}" pip install --python "${PYTHON_BIN}" build +fi + +log "INFO" "Checking build module" +if ! "${PYTHON_BIN}" -m build --version >/dev/null 2>&1; then + abort "${PYTHON_BIN} -m build is not available, install build first" +fi + +run_quality_steps() { + log "INFO" "Running quality checks" + "${UV_BIN}" pip install --python "${PYTHON_BIN}" -e ".[dev]" + "${PYTHON_BIN}" -m pre_commit run --all-files + "${PYTHON_BIN}" -m pytest tests/ -v -m "not slow" +} + +if [[ "${RUN_QUALITY}" == "true" ]]; then + run_quality_steps +else + log "INFO" "Quality steps available via --run-quality" + log "INFO" " - pre-commit run --all-files" + log "INFO" " - pip install -e '.[dev]'" + log "INFO" " - pytest tests/ -v -m \"not slow\"" +fi + +cleanup_artifacts() { + log "INFO" "Cleaning previous build artifacts" + rm -rf build dist "${PROJECT_NAME}.egg-info" "${PROJECT_NAME//-/_}.egg-info" +} + +if [[ "${SKIP_CLEAN}" == "true" ]]; then + log "INFO" "Skipping cleanup as requested" +else + cleanup_artifacts +fi + +log "INFO" "Building source and wheel distributions" +"${PYTHON_BIN}" -m build + +log "INFO" "Build finished, artifacts:" +ls -lh dist diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bacd68e5d22bf552f1d455be9c25682509d59f7e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +""" +Test suite for vLLM-Omni. + +This package contains unit tests, integration tests, and benchmarks +for vLLM-Omni. 
+""" diff --git a/tests/__pycache__/__init__.cpython-310.pyc.1656210441360 b/tests/__pycache__/__init__.cpython-310.pyc.1656210441360 new file mode 100644 index 0000000000000000000000000000000000000000..e4a720065f980c0b729dfa079b668f57af538dd9 Binary files /dev/null and b/tests/__pycache__/__init__.cpython-310.pyc.1656210441360 differ diff --git a/tests/__pycache__/__init__.cpython-310.pyc.2499248702608 b/tests/__pycache__/__init__.cpython-310.pyc.2499248702608 new file mode 100644 index 0000000000000000000000000000000000000000..e4a720065f980c0b729dfa079b668f57af538dd9 Binary files /dev/null and b/tests/__pycache__/__init__.cpython-310.pyc.2499248702608 differ diff --git a/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.19912 b/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.19912 new file mode 100644 index 0000000000000000000000000000000000000000..d287d7d0edc6ed5d31384135d794030c64e8b729 Binary files /dev/null and b/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.19912 differ diff --git a/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.40548 b/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.40548 new file mode 100644 index 0000000000000000000000000000000000000000..d287d7d0edc6ed5d31384135d794030c64e8b729 Binary files /dev/null and b/tests/__pycache__/conftest.cpython-310-pytest-8.3.5.pyc.40548 differ diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..2c624b8e76acaca81dc000dba55b27b6b7d4f997 --- /dev/null +++ b/tests/benchmarks/test_serve_cli.py @@ -0,0 +1,60 @@ +import subprocess +from pathlib import Path + +import pytest + +from tests.conftest import OmniServer + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] +stage_configs = [str(Path(__file__).parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +@pytest.fixture(scope="module") +def omni_server(request): + """Start vLLM-Omni server as a subprocess with actual model weights. + Uses session scope so the server starts only once for the entire test session. + Multi-stage initialization can take 10-20+ minutes. 
+ """ + model, stage_config_path = request.param + + print(f"Starting OmniServer with model: {model}") + print("This may take 10-20+ minutes for initialization...") + + with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server: + print("OmniServer started successfully") + yield server + print("OmniServer stopped") + + +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_bench_serve_chat(omni_server): + command = [ + "vllm", + "bench", + "serve", + "--omni", + "--model", + omni_server.model, + "--port", + str(omni_server.port), + "--dataset-name", + "random", + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + "--endpoint", + "/v1/chat/completions", + "--backend", + "openai-chat-omni", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d7e6cc8a7e88f7d47798ec9eb72a202dfa49dd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,1005 @@ +import base64 +import datetime +import io +import math +import os +import random + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +# Set CPU device for CI environments without GPU +if "VLLM_TARGET_DEVICE" not in os.environ: + os.environ["VLLM_TARGET_DEVICE"] = "cpu" + +import gc +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np +import psutil +import pytest +import torch +import yaml +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory +from vllm.logger import init_logger +from vllm.utils.network_utils import get_open_port + +logger = init_logger(__name__) + + +@pytest.fixture(autouse=True) +def default_vllm_config(): + """Set a default VllmConfig for all tests. + + This fixture is auto-used for all tests to ensure that any test + that directly instantiates vLLM CustomOps (e.g., RMSNorm, LayerNorm) + or model components has the required VllmConfig context. + + This fixture is required for vLLM 0.14.0+ where CustomOp initialization + requires a VllmConfig context set via set_current_vllm_config(). 
+ """ + from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config + + # Use CPU device if no GPU is available (e.g., in CI environments) + has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0 + device = "cuda" if has_gpu else "cpu" + device_config = DeviceConfig(device=device) + + with set_current_vllm_config(VllmConfig(device_config=device_config)): + yield + + +@pytest.fixture(autouse=True) +def clean_gpu_memory_between_tests(): + print("\n=== PRE-TEST GPU CLEANUP ===") + _run_pre_test_cleanup() + yield + _run_post_test_cleanup() + + +def _run_pre_test_cleanup(enable_force=False): + if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + print("GPU cleanup disabled") + return + + print("Pre-test GPU status:") + + num_gpus = torch.cuda.device_count() + if num_gpus > 0: + try: + from tests.utils import wait_for_gpu_memory_to_clear + + wait_for_gpu_memory_to_clear( + devices=list(range(num_gpus)), + threshold_ratio=0.05, + ) + except Exception as e: + print(f"Pre-test cleanup note: {e}") + + +def _run_post_test_cleanup(enable_force=False): + if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force: + print("GPU cleanup disabled") + return + + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + + print("Post-test GPU status:") + _print_gpu_processes() + + +def _print_gpu_processes(): + """Print GPU information including nvidia-smi and system processes""" + + print("\n" + "=" * 80) + print("NVIDIA GPU Information (nvidia-smi)") + print("=" * 80) + + try: + nvidia_result = subprocess.run( + ["nvidia-smi"], + capture_output=True, + text=True, + timeout=5, + ) + + if nvidia_result.returncode == 0: + lines = nvidia_result.stdout.strip().split("\n") + for line in lines[:20]: + print(line) + + if len(lines) > 20: + print(f"... 
(showing first 20 of {len(lines)} lines)") + else: + print("nvidia-smi command failed") + + except (subprocess.TimeoutExpired, FileNotFoundError): + print("nvidia-smi not available or timed out") + except Exception as e: + print(f"Error running nvidia-smi: {e}") + + print("\n" + "=" * 80) + print("Detailed GPU Processes (nvidia-smi pmon)") + print("=" * 80) + + try: + pmon_result = subprocess.run( + ["nvidia-smi", "pmon", "-c", "1"], + capture_output=True, + text=True, + timeout=3, + ) + + if pmon_result.returncode == 0 and pmon_result.stdout.strip(): + print(pmon_result.stdout) + else: + print("No active GPU processes found via nvidia-smi pmon") + + except Exception: + print("nvidia-smi pmon not available") + + print("\n" + "=" * 80) + print("System Processes with GPU keywords") + print("=" * 80) + + +def dummy_messages_from_mix_data( + system_prompt: dict[str, Any] = None, + video_data_url: Any = None, + audio_data_url: Any = None, + image_data_url: Any = None, + content_text: str = None, +): + """Create messages with video、image、audio data URL for OpenAI API.""" + + if content_text is not None: + content = [{"type": "text", "text": content_text}] + else: + content = [] + + media_items = [] + if isinstance(video_data_url, list): + for video_url in video_data_url: + media_items.append((video_url, "video")) + else: + media_items.append((video_data_url, "video")) + + if isinstance(image_data_url, list): + for url in image_data_url: + media_items.append((url, "image")) + else: + media_items.append((image_data_url, "image")) + + if isinstance(audio_data_url, list): + for url in audio_data_url: + media_items.append((url, "audio")) + else: + media_items.append((audio_data_url, "audio")) + + content.extend( + {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}} + for url, media_type in media_items + if url is not None + ) + messages = [{"role": "user", "content": content}] + if system_prompt is not None: + messages = [system_prompt] + messages + return messages + + +def generate_synthetic_audio( + duration: int, # seconds + num_channels: int, # 1:Mono,2:Stereo 5:5.1 surround sound + sample_rate: int = 48000, # Default use 48000Hz. 
+ save_to_file: bool = False, +) -> dict[str, Any]: + """ "Generate synthetic audio with rain.""" + import soundfile as sf + + # Initialize audio data array + num_samples = int(sample_rate * duration) + audio_data = np.zeros((num_samples, num_channels), dtype=np.float32) + + # Configure parameters based on rain intensity + drop_density = 10 # Number of raindrops per second + drop_volume = 0.15 # Volume of individual raindrops + background_volume = 0.02 # Volume of background rain noise + + # Pink noise sounds more natural than white noise for rain + white_noise = np.random.randn(num_samples) + pink_noise = np.convolve(white_noise, np.ones(8) / 8, mode="same") + pink_noise = pink_noise / np.max(np.abs(pink_noise)) if np.max(np.abs(pink_noise)) > 0 else pink_noise + bg_noise = pink_noise * background_volume + + # Add background noise to all channels + for ch in range(num_channels): + audio_data[:, ch] += bg_noise + + # Total number of raindrops = density × duration × channels for stereo effect + total_drops = int(drop_density * duration * num_channels) + + for _ in range(total_drops): + # Random timing for raindrop + drop_time = random.uniform(0, duration) + + # Random duration of raindrop sound (0.01-0.05 seconds) + drop_duration = random.uniform(0.01, 0.05) + + # Random frequency gives variation in raindrop pitch + drop_freq = random.uniform(500, 5000) # Hz + + # Random channel selection for stereo positioning + channel = random.randint(0, num_channels - 1) + + # Calculate sample positions for this raindrop + start_sample = int(drop_time * sample_rate) + drop_samples = int(drop_duration * sample_rate) + end_sample = min(start_sample + drop_samples, num_samples) + + if start_sample < end_sample: + # Generate the raindrop sound + num_drop_samples = end_sample - start_sample + t = np.arange(num_drop_samples) / sample_rate + + # Basic sine wave for raindrop sound + drop_sound = drop_volume * np.sin(2 * math.pi * drop_freq * t) + + # Apply envelope for natural attack and decay + envelope = np.ones(num_drop_samples) + attack_samples = int(num_drop_samples * 0.1) # 10% of samples for attack + decay_samples = num_drop_samples - attack_samples + + if attack_samples > 0: + # Linear attack: volume increases from 0 to 1 + envelope[:attack_samples] = np.linspace(0, 1, attack_samples) + + if decay_samples > 0: + # Exponential decay for natural sound fade + decay = np.exp(-8 * t[attack_samples:] / drop_duration) + envelope[attack_samples:] = decay + + # Apply envelope to raindrop sound + drop_sound *= envelope + + # Add raindrop sound to selected channel + audio_data[start_sample:end_sample, channel] += drop_sound + + # Step 3: Add simple reverb effect for realism + # Reverb simulates sound reflections in environment + if duration > 2: + # Single delay reverb (100ms delay) + delay_samples = int(0.1 * sample_rate) + if delay_samples < num_samples: + for ch in range(num_channels): + delayed = np.zeros(num_samples) + delayed[delay_samples:] = audio_data[:-delay_samples, ch] * 0.3 + audio_data[:, ch] += delayed + + # Step 4: Normalize audio to prevent clipping + # Find maximum amplitude and scale to 80% of maximum volume + max_amp = np.max(np.abs(audio_data)) + if max_amp > 0: + audio_data = audio_data / max_amp * 0.8 + + # Handle file saving + audio_bytes = None + + if save_to_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"audio_{num_channels}ch_{timestamp}.wav" + + try: + sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16") + print(f"Audio 
saved: {output_path}") + + with open(output_path, "rb") as f: + audio_bytes = f.read() + except Exception as e: + print(f"Save failed: {e}") + save_to_file = False + + # If not saving or save failed, create in memory + if not save_to_file or audio_bytes is None: + buffer = io.BytesIO() + sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16") + buffer.seek(0) + audio_bytes = buffer.read() + + # Return result + base64_audio = base64.b64encode(audio_bytes).decode("utf-8") + result = { + "base64": base64_audio, + } + if save_to_file and output_path: + result["file_path"] = output_path + + return result + + +def generate_synthetic_video(width: int, height: int, num_frames: int, save_to_file: bool = False) -> str: + """Generate synthetic video with bouncing balls and return base64 string.""" + + import cv2 + import imageio + + # Create random balls + num_balls = random.randint(3, 8) + balls = [] + + for _ in range(num_balls): + radius = min(width, height) // 8 + if radius < 1: + raise ValueError(f"Video dimensions ({width}x{height}) are too small for synthetic video generation") + x = random.randint(radius, width - radius) + y = random.randint(radius, height - radius) + + speed = random.uniform(3.0, 8.0) + angle = random.uniform(0, 2 * math.pi) + vx = speed * math.cos(angle) + vy = speed * math.sin(angle) + + # OpenCV uses BGR format, but imageio expects RGB + # We'll create in BGR first, then convert to RGB later + color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255)) + + balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr}) + + # Generate video frames + video_frames = [] + + for frame_idx in range(num_frames): + # Create black background (BGR format) + frame_bgr = np.zeros((height, width, 3), dtype=np.uint8) + + for ball in balls: + # Update position + ball["x"] += ball["vx"] + ball["y"] += ball["vy"] + + # Boundary collision detection + if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width: + ball["vx"] = -ball["vx"] + ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"])) + + if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height: + ball["vy"] = -ball["vy"] + ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"])) + + # Use cv2 to draw circle + x, y = int(ball["x"]), int(ball["y"]) + radius = ball["radius"] + + # Draw solid circle (main circle) + cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1) + + # Add simple 3D effect: draw a brighter center + if radius > 3: # Only add highlight when radius is large enough + highlight_radius = max(1, radius // 2) + highlight_x = max(highlight_radius, min(x - radius // 4, width - highlight_radius)) + highlight_y = max(highlight_radius, min(y - radius // 4, height - highlight_radius)) + + # Create highlight color (brighter) + highlight_color = tuple(min(c + 40, 255) for c in ball["color_bgr"]) + cv2.circle(frame_bgr, (highlight_x, highlight_y), highlight_radius, highlight_color, -1) + + # Convert BGR to RGB for imageio + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + video_frames.append(frame_rgb) + + video_bytes = None + saved_file_path = None + + buffer = io.BytesIO() + writer_kwargs = { + "format": "mp4", + "fps": 30, + "codec": "libx264", + "quality": 7, + "pixelformat": "yuv420p", + "macro_block_size": 16, + "ffmpeg_params": [ + "-preset", + "medium", + "-crf", + "23", + "-movflags", + "+faststart", + "-pix_fmt", + "yuv420p", + "-vf", + f"scale={width}:{height}", + ], 
+ } + + if save_to_file: + import datetime + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"video_{width}x{height}_{timestamp}.mp4" + try: + with imageio.get_writer(output_path, **writer_kwargs) as writer: + for frame in video_frames: + writer.append_data(frame) + + saved_file_path = output_path + print(f"Video saved to: {saved_file_path}") + with open(output_path, "rb") as f: + video_bytes = f.read() + + except Exception as e: + print(f"Warning: Failed to save video to file {output_path}: {e}") + save_to_file = False + + if not save_to_file or video_bytes is None: + with imageio.get_writer(buffer, **writer_kwargs) as writer: + for frame in video_frames: + writer.append_data(frame) + + buffer.seek(0) + video_bytes = buffer.read() + + base64_video = base64.b64encode(video_bytes).decode("utf-8") + + result = { + "base64": base64_video, + } + if save_to_file and saved_file_path: + result["file_path"] = saved_file_path + + return result + + +def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> Any: + """Generate synthetic image with randomly colored squares and return base64 string.""" + from PIL import Image, ImageDraw + + # Create white background + image = Image.new("RGB", (width, height), (255, 255, 255)) + draw = ImageDraw.Draw(image) + + # Generate random number of squares + num_squares = random.randint(3, 8) + + for _ in range(num_squares): + # Random square size + square_size = random.randint(min(width, height) // 8, min(width, height) // 4) + + # Random position + x = random.randint(0, width - square_size - 1) + y = random.randint(0, height - square_size - 1) + + # Random color + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + + # Random border width + border_width = random.randint(1, 5) + + # Draw square + draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width) + + # Handle file saving + image_bytes = None + saved_file_path = None + + if save_to_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = f"image_{width}x{height}_{timestamp}.jpg" + + try: + # Save image to file + image.save(output_path, format="JPEG", quality=85, optimize=True) + saved_file_path = output_path + print(f"Image saved to: {saved_file_path}") + + # Read file for base64 encoding + with open(output_path, "rb") as f: + image_bytes = f.read() + + except Exception as e: + print(f"Warning: Failed to save image to file {output_path}: {e}") + save_to_file = False + + # If not saving or save failed, create in memory + if not save_to_file or image_bytes is None: + buffer = io.BytesIO() + image.save(buffer, format="JPEG", quality=85, optimize=True) + buffer.seek(0) + image_bytes = buffer.read() + + # Generate base64 + base64_image = base64.b64encode(image_bytes).decode("utf-8") + + # Return result + result = { + "base64": base64_image, + } + if save_to_file and saved_file_path: + result["file_path"] = saved_file_path + + return result + + +def preprocess_text(text): + import re + + word_to_num = { + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + + for word, num in word_to_num.items(): + pattern = r"\b" + re.escape(word) + r"\b" + text = re.sub(pattern, num, text, flags=re.IGNORECASE) + + text = re.sub(r"[^\w\s]", "", text) + text = re.sub(r"\s+", " ", text) + return text.lower().strip() + + +def 
cosine_similarity_text(text1, text2, n: int = 3): + from collections import Counter + + if not text1 or not text2: + return 0.0 + + text1 = preprocess_text(text1) + text2 = preprocess_text(text2) + + ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)] + ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)] + + counter1 = Counter(ngrams1) + counter2 = Counter(ngrams2) + + all_ngrams = set(counter1.keys()) | set(counter2.keys()) + vec1 = [counter1.get(ng, 0) for ng in all_ngrams] + vec2 = [counter2.get(ng, 0) for ng in all_ngrams] + + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + norm1 = sum(a * a for a in vec1) ** 0.5 + norm2 = sum(b * b for b in vec2) ** 0.5 + + if norm1 == 0 or norm2 == 0: + return 0.0 + return dot_product / (norm1 * norm2) + + +def convert_audio_to_text(audio_data): + """ + Convert base64 encoded audio data to text using speech recognition. + """ + import whisper + + audio_data = base64.b64decode(audio_data) + output_path = f"./test_{int(time.time())}" + with open(output_path, "wb") as audio_file: + audio_file.write(audio_data) + + print(f"audio data is saved: {output_path}") + + model = whisper.load_model("base") + text = model.transcribe( + output_path, + temperature=0.0, + word_timestamps=True, + condition_on_previous_text=False, + )["text"] + if text: + return text + else: + return "" + + +def merge_base64_and_convert_to_text(base64_list): + """ + Merge a list of base64 encoded audio data and convert to text. + """ + import whisper + from pydub import AudioSegment + + merged_audio = None + for base64_data in base64_list: + audio_data = base64.b64decode(base64_data.split(",", 1)[-1]) + seg = AudioSegment.from_file(io.BytesIO(audio_data)) + if merged_audio is None: + merged_audio = seg + else: + merged_audio += seg + output_path = f"./test_{int(time.time())}" + merged_audio.export(output_path, format="wav") + model = whisper.load_model("base") + text = model.transcribe( + output_path, + temperature=0.0, + word_timestamps=True, + condition_on_previous_text=False, + )["text"] + if text: + return text + else: + return "" + + +def modify_stage_config( + yaml_path: str, + updates: dict[str, Any], + deletes: dict[str, Any] = None, +) -> str: + """ + Modify configurations in a YAML file, supporting both top-level and stage-specific modifications, + including addition, modification, and deletion of configurations. + + Args: + yaml_path: Path to the YAML configuration file. + updates: Dictionary containing both top-level and stage-specific modifications to add or update. + Format: { + 'async_chunk': True, + 'stage_args': { + 0: {'engine_args.max_model_len': 5800}, + 1: {'runtime.max_batch_size': 2} + } + } + deletes: Dictionary containing configurations to delete. + Format: { + 'old_config': None, # Delete entire key + 'stage_args': { + 0: ['engine_args.old_param'], + 1: ['runtime.unused_setting'] + } + } + + Returns: + str: Path to the newly created modified YAML file with timestamp suffix. + """ + path = Path(yaml_path) + if not path.exists(): + raise FileNotFoundError(f"yaml does not exist: {path}") + + try: + with open(yaml_path, encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + except Exception as e: + raise ValueError(f"Cannot parse YAML file: {e}") + + # Helper function to apply update + def apply_update(config_dict: dict, key_path: str, value: Any) -> None: + """Apply update to dictionary using dot-separated path.""" + # Handle direct list assignment (e.g., engine_input_source: [1, 2]) + if "." 
not in key_path: + # Simple key, set directly + config_dict[key_path] = value + return + + current = config_dict + keys = key_path.split(".") + + for i in range(len(keys) - 1): + key = keys[i] + + # Handle list indices + if key.isdigit() and isinstance(current, list): + index = int(key) + if index < 0: + raise ValueError(f"Negative list index not allowed: {index}") + if index >= len(current): + # Expand list if needed + while len(current) <= index: + # If we need to go deeper (more keys after this), create a dict + # Otherwise, create None placeholder + current.append({} if i < len(keys) - 2 else None) + current = current[index] + elif isinstance(current, dict): + # Handle dictionary keys + if key not in current: + # If there are more keys after this, create appropriate structure + if i < len(keys) - 1: + # Check if next key is a digit (list index) or string (dict key) + if keys[i + 1].isdigit(): + current[key] = [] + else: + current[key] = {} + else: + # This is the last key, create based on value type + current[key] = [] if isinstance(value, list) else {} + elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1: + # If current value is not dict/list but we need to go deeper, replace it + if keys[i + 1].isdigit(): + current[key] = [] + else: + current[key] = {} + current = current[key] + else: + # Current is not a dict or list, cannot traverse further + raise TypeError( + f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}" + ) + + # Set the final value + last_key = keys[-1] + if isinstance(current, list) and last_key.isdigit(): + # Setting a value in a list by index + index = int(last_key) + if index < 0: + raise ValueError(f"Negative list index not allowed: {index}") + if index >= len(current): + # Expand list if needed + while len(current) <= index: + current.append(None) + current[index] = value + elif isinstance(current, dict): + # Special case: if the value is a list and we're setting a top-level key + # Example: updating engine_input_source with [1, 2] + current[last_key] = value + else: + # Current is not a dict, cannot set key + raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.") + + # Helper function to delete by path + def delete_by_path(config_dict: dict, path: str) -> None: + """Delete configuration by dot-separated path.""" + if not path: + return + + current = config_dict + keys = path.split(".") + + # Traverse to the parent + for i in range(len(keys) - 1): + key = keys[i] + + # Handle list indices + if key.isdigit() and isinstance(current, list): + index = int(key) + if index < 0 or index >= len(current): + raise KeyError(f"List index {index} out of bounds") + current = current[index] + elif isinstance(current, dict): + if key not in current: + raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist") + current = current[key] + else: + raise TypeError( + f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. 
It's a {type(current).__name__}" + ) + + # Delete the item + last_key = keys[-1] + + if isinstance(current, list) and last_key.isdigit(): + index = int(last_key) + if index < 0 or index >= len(current): + raise KeyError(f"List index {index} out of bounds") + del current[index] + elif isinstance(current, dict) and last_key in current: + del current[last_key] + else: + raise KeyError(f"Path {path} does not exist") + + # Apply deletions first + if deletes: + for key, value in deletes.items(): + if key == "stage_args": + if value and isinstance(value, dict): + stage_args = config.get("stage_args", []) + if not stage_args: + raise ValueError("stage_args does not exist in config") + + for stage_id, delete_paths in value.items(): + if not delete_paths: + continue + + # Find stage by ID + target_stage = None + for stage in stage_args: + if stage.get("stage_id") == stage_id: + target_stage = stage + break + + if target_stage is None: + available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s] + raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}") + + # Delete specified paths in this stage + for path in delete_paths: + if path: # Skip empty paths + delete_by_path(target_stage, path) + elif "." in key: + # Delete using dot-separated path + delete_by_path(config, key) + elif value is None and key in config: + # Delete entire key + del config[key] + + # Apply updates + for key, value in updates.items(): + if key == "stage_args": + if value and isinstance(value, dict): + stage_args = config.get("stage_args", []) + if not stage_args: + raise ValueError("stage_args does not exist in config") + + for stage_id, stage_updates in value.items(): + # Find stage by ID + target_stage = None + for stage in stage_args: + if stage.get("stage_id") == stage_id: + target_stage = stage + break + + if target_stage is None: + available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s] + raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}") + + # Apply updates to this stage + for path, val in stage_updates.items(): + # Check if this is a simple key (not dot-separated) + # Example: 'engine_input_source' vs 'engine_args.max_model_len' + if "." not in path: + # Direct key assignment (e.g., updating a list value) + target_stage[path] = val + else: + # Dot-separated path (e.g., nested dict access) + apply_update(target_stage, path, val) + elif "." in key: + # Apply using dot-separated path + apply_update(config, key, value) + else: + # Direct top-level key + config[key] = value + + # Save to new file with timestamp + timestamp = int(time.time()) + base_name = yaml_path.rsplit(".", 1)[0] if "." 
in yaml_path else yaml_path + output_path = f"{base_name}_{timestamp}.yaml" + + with open(output_path, "w", encoding="utf-8") as f: + yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2) + + return output_path + + +class OmniServer: + """Omniserver for vLLM-Omni tests.""" + + def __init__( + self, + model: str, + serve_args: list[str], + *, + env_dict: dict[str, str] | None = None, + ) -> None: + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + cleanup_dist_env_and_memory() + self.model = model + self.serve_args = serve_args + self.env_dict = env_dict + self.proc: subprocess.Popen | None = None + self.host = "127.0.0.1" + self.port = get_open_port() + + def _start_server(self) -> None: + """Start the vLLM-Omni server subprocess.""" + env = os.environ.copy() + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if self.env_dict is not None: + env.update(self.env_dict) + + cmd = [ + sys.executable, + "-m", + "vllm_omni.entrypoints.cli.main", + "serve", + self.model, + "--omni", + "--host", + self.host, + "--port", + str(self.port), + ] + self.serve_args + + print(f"Launching OmniServer with: {' '.join(cmd)}") + self.proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + ) + + # Wait for server to be ready + max_wait = 1200 # 20 minutes + start_time = time.time() + while time.time() - start_time < max_wait: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + result = sock.connect_ex((self.host, self.port)) + if result == 0: + print(f"Server ready on {self.host}:{self.port}") + return + except Exception: + pass + time.sleep(2) + + raise RuntimeError(f"Server failed to start within {max_wait} seconds") + + def _kill_process_tree(self, pid): + """kill process and its children with verification""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + + # Get all PIDs first + all_pids = [pid] + [child.pid for child in children] + + # Terminate children + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + # Wait for children + gone, still_alive = psutil.wait_procs(children, timeout=10) + + # Kill remaining children + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + # Terminate parent + try: + parent.terminate() + parent.wait(timeout=10) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + # VERIFICATION: Check if all processes are gone + time.sleep(1) # Give system time + alive_processes = [] + for check_pid in all_pids: + if psutil.pid_exists(check_pid): + alive_processes.append(check_pid) + + if alive_processes: + print(f"Warning: Processes still alive: {alive_processes}") + # Optional: Try system kill + import subprocess + + for alive_pid in alive_processes: + try: + subprocess.run(["kill", "-9", str(alive_pid)], timeout=2) + except Exception as e: + print(f"Cleanup failed: {e}") + + except psutil.NoSuchProcess: + pass + + def __enter__(self): + self._start_server() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.proc: + self._kill_process_tree(self.proc.pid) + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + cleanup_dist_env_and_memory() diff --git a/tests/diffusion/attention/test_attention_sp.py 
b/tests/diffusion/attention/test_attention_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..447d1ea93bd0eed48e66b79b8e0ecece636e1009 --- /dev/null +++ b/tests/diffusion/attention/test_attention_sp.py @@ -0,0 +1,510 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import pickle +import tempfile + +import pytest +import torch + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import ( + DiffusionParallelConfig, + OmniDiffusionConfig, +) +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.platforms import current_omni_platform + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +class TestAttentionModel(torch.nn.Module): + """Test model using Attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = False, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = hidden_size + self.attention = Attention( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=1.0 / (head_size**0.5), + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + # Linear projection layers for Q, K, V + self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_size) + self.k_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.v_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size) + self.o_proj = torch.nn.Linear(num_heads * head_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through attention layer.""" + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to (batch_size, seq_len, num_heads, head_size) + q = q.view(batch_size, seq_len, self.num_heads, self.head_size) + k = k.view(batch_size, seq_len, k.shape[-1] // self.head_size, self.head_size) + v = v.view(batch_size, seq_len, v.shape[-1] // self.head_size, self.head_size) + + # Apply attention + attn_output = self.attention(q, k, v) + + # Reshape back and project + attn_output = attn_output.view(batch_size, seq_len, -1) + output = self.o_proj(attn_output) + + return output + + +class TestMultiLayerAttentionModel(torch.nn.Module): + """Test model with multiple attention layers.""" + + def __init__( + self, + num_layers: int, + num_heads: int, + head_size: int, + hidden_size: int, + causal: bool = True, + num_kv_heads: int | None = None, + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.layers = torch.nn.ModuleList( + [ + TestAttentionModel( + num_heads=num_heads, + head_size=head_size, + hidden_size=hidden_size, + causal=causal, + num_kv_heads=num_kv_heads, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass through multiple attention layers.""" + for layer in self.layers: + hidden_states = hidden_states + layer(hidden_states) + return hidden_states + + +@pytest.mark.parametrize( + "test_model_cls", + [ + TestMultiLayerAttentionModel, + ], +) +@pytest.mark.parametrize("ulysses_degree", [2]) +@pytest.mark.parametrize("ring_degree", [2]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [8]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # [torch.float16, torch.bfloat16] +@pytest.mark.parametrize("use_sync", [False]) +@pytest.mark.parametrize("dynamic", [False]) +@pytest.mark.parametrize("use_compile", [False]) +@pytest.mark.parametrize("attn_backend", ["sdpa", "flash_attn"]) +def test_sequence_parallel( + ulysses_degree: int, + ring_degree: int, + test_model_cls: type[torch.nn.Module], + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + attn_backend: str, +): + """Test Ulysses attention by comparing with and without SP enabled.""" + sequence_parallel_size = ulysses_degree * ring_degree + + # Skip if not enough GPUs available + available_gpus = current_omni_platform.get_device_count() + if available_gpus < sequence_parallel_size: + pytest.skip(f"Test requires {sequence_parallel_size} GPUs but only {available_gpus} available") + + # Create temporary files to share results between processes + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + baseline_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + sp_output_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + model_state_file = f.name + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: + input_data_file = f.name + + try: + # Step 1: Run without SP (baseline with ulysses_degree=1, ring_degree=1) + print("\n[Baseline] Running without SP (ulysses_degree=1, ring_degree=1)...") + torch.multiprocessing.spawn( + ulysses_attention_on_test_model, + args=( + 1, # num_processes = 1 for baseline + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + 1, # ulysses_degree = 1 + 1, # ring_degree = 1 + 1, # sequence_parallel_size = 1 + baseline_output_file, + model_state_file, + input_data_file, + True, # is_baseline + attn_backend, + ), + nprocs=1, + ) + + # Step 2: Run with SP enabled + print(f"\n[SP Test] Running with SP (ulysses_degree={ulysses_degree}, ring_degree={ring_degree})...") + torch.multiprocessing.spawn( + ulysses_attention_on_test_model, + args=( + sequence_parallel_size, # num_processes + test_model_cls, + batch_size, + seq_len, + num_heads, + head_size, + dtype, + causal, + use_sync, + dynamic, + use_compile, + ulysses_degree, + ring_degree, + sequence_parallel_size, + sp_output_file, + model_state_file, + input_data_file, + False, # is_baseline + attn_backend, + ), + nprocs=sequence_parallel_size, + ) + + # Step 3: Verify input consistency and compare outputs + print(f"\n{'=' * 80}") + print("Verifying input data consistency...") + with open(input_data_file, "rb") as f: + input_data = pickle.load(f) + input_checksum = hash(input_data.tobytes()) + print(f" Input data shape: {input_data.shape}") + 
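+        # Note: hash() on bytes is salted per interpreter run (PYTHONHASHSEED), so the
+        # checksum below is only meaningful for eyeballing within a single run; a stable
+        # digest such as hashlib.sha256(input_data.tobytes()).hexdigest() is an option.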
print(f" Input data checksum: {input_checksum}") + print(" ✓ Both baseline and SP used the same input data") + + print(f"\n{'=' * 80}") + print("Comparing outputs between baseline and SP...") + with open(baseline_output_file, "rb") as f: + baseline_output = pickle.load(f) + with open(sp_output_file, "rb") as f: + sp_output = pickle.load(f) + + # Convert to tensors for comparison + baseline_tensor = torch.tensor(baseline_output) + sp_tensor = torch.tensor(sp_output) + + print(f" Baseline output shape: {baseline_tensor.shape}") + print(f" SP output shape: {sp_tensor.shape}") + assert baseline_tensor.shape == sp_tensor.shape, "Output shapes must match!" + + # Calculate differences + abs_diff = torch.abs(baseline_tensor - sp_tensor) + max_abs_diff = abs_diff.max().item() + mean_abs_diff = abs_diff.mean().item() + + # Calculate relative difference (avoid division by zero) + baseline_abs = torch.abs(baseline_tensor) + relative_diff = abs_diff / (baseline_abs + 1e-8) + max_relative_diff = relative_diff.max().item() + mean_relative_diff = relative_diff.mean().item() + + print(f"\n{'=' * 80}") + print("Output Difference Analysis:") + print(f" - Max absolute difference: {max_abs_diff:.6e}") + print(f" - Mean absolute difference: {mean_abs_diff:.6e}") + print(f" - Max relative difference: {max_relative_diff:.6e}") + print(f" - Mean relative difference: {mean_relative_diff:.6e}") + print(f" - Baseline output range: [{baseline_tensor.min().item():.6e}, {baseline_tensor.max().item():.6e}]") + print(f" - SP output range: [{sp_tensor.min().item():.6e}, {sp_tensor.max().item():.6e}]") + print(f"{'=' * 80}\n") + + # Assert that differences are within acceptable tolerance + # For FP16/BF16, we expect some numerical differences due to different computation order under parallelism. + # If we use the same backend (e.g. Flash Attention) for both baseline and SP, differences should be smaller. 
+ if dtype == torch.float16: + atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention + elif dtype == torch.bfloat16: + atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention + else: + atol, rtol = 1e-5, 1e-4 + + assert max_abs_diff < atol or max_relative_diff < rtol, ( + f"Output difference too large: max_abs_diff={max_abs_diff:.6e}, " + f"max_relative_diff={max_relative_diff:.6e}, " + f"tolerance: atol={atol}, rtol={rtol}" + ) + + print("✓ Test passed: SP output matches baseline within tolerance") + + finally: + # Clean up temporary files + for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file]: + if os.path.exists(f): + os.remove(f) + + +def ulysses_attention_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + causal: bool, + use_sync: bool, + dynamic: bool, + use_compile: bool, + ulysses_degree: int, + ring_degree: int, + sequence_parallel_size: int, + output_file: str, + model_state_file: str, + input_data_file: str, + is_baseline: bool, + attn_backend: str, +): + """Run Ulysses attention test on a test model and save results for comparison.""" + # Use fixed seed for reproducibility across baseline and SP runs + RANDOM_SEED = 42 + current_omni_platform.seed_everything(RANDOM_SEED) + + mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}") + + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + # Initialize distributed environment + init_distributed_environment() + + # Set up OmniDiffusionConfig with parallel config + parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + cfg_parallel_size=1, + ) + + od_config = OmniDiffusionConfig( + model="test_model", + dtype=dtype, + parallel_config=parallel_config, + attention_backend=attn_backend, # Set the attention backend here + ) + + # Initialize model parallel + initialize_model_parallel( + data_parallel_size=1, + cfg_parallel_size=1, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + # Set the config so Attention can access it + with set_forward_context(omni_diffusion_config=od_config): + # Create model + hidden_size = num_heads * head_size + + # Create model with appropriate parameters + model_kwargs = { + "num_heads": num_heads, + "head_size": head_size, + "hidden_size": hidden_size, + "causal": causal, + "num_kv_heads": None, + "scatter_idx": 2, + "gather_idx": 1, + "use_sync": use_sync, + } + + if test_model_cls == TestMultiLayerAttentionModel: + model_kwargs["num_layers"] = 2 + + model = test_model_cls(**model_kwargs) + model = model.to(device).to(dtype) + + # For baseline: Generate and save model state and input data + # This ensures both baseline and SP use exactly the same initialization + if 
is_baseline and local_rank == 0: + # Save model state for reuse (before any computation) + model_state = {k: v.cpu() for k, v in model.state_dict().items()} + with open(model_state_file, "wb") as f: + pickle.dump(model_state, f) + + # Generate and save full input data with fixed seed + # Reinitialize RNG to ensure reproducibility + torch.manual_seed(42) + current_omni_platform.seed_everything(42) + full_hidden_states = torch.randn( + (batch_size, seq_len, hidden_size), + dtype=dtype, + device="cpu", + ) + with open(input_data_file, "wb") as f: + pickle.dump(full_hidden_states.detach().cpu().float().numpy(), f) + + print("[Baseline] Saved model state and input data") + + # Synchronize to ensure baseline has saved data before SP loads it + if world_size > 1: + torch.distributed.barrier() + + # IMPORTANT: Both baseline and SP load the same model state and input data + # This ensures exact same initialization and input for fair comparison + with open(model_state_file, "rb") as f: + model_state = pickle.load(f) + model.load_state_dict({k: v.to(device).to(dtype) for k, v in model_state.items()}) + + with open(input_data_file, "rb") as f: + full_hidden_states_np = pickle.load(f) + full_hidden_states = torch.from_numpy(full_hidden_states_np).to(device).to(dtype) + + print(f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}") + + # Split input sequence according to sequence parallel BEFORE model forward + # Each rank gets a contiguous chunk of the sequence dimension + local_seq_len = seq_len // sequence_parallel_size + start_idx = local_rank * local_seq_len + end_idx = start_idx + local_seq_len + hidden_states = full_hidden_states[:, start_idx:end_idx, :].contiguous() + + print( + f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, " + f"indices=[{start_idx}:{end_idx}], local_shape={hidden_states.shape}" + ) + + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) + torch._dynamo.mark_dynamic(hidden_states, 1) + + # Compile model if requested + if use_compile: + model = torch.compile(model) + + # Run forward pass with local sequence chunk + print(f"[Rank {local_rank}] Running forward pass...") + output = model(hidden_states) + print(f"[Rank {local_rank}] Forward pass completed, output shape: {output.shape}") + + # Verify output shape + assert output.shape == (batch_size, local_seq_len, hidden_size), ( + f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}" + ) + + # Gather outputs from all ranks AFTER computation + if world_size > 1: + print(f"[Rank {local_rank}] Gathering outputs from all {world_size} ranks...") + # Gather all outputs to rank 0 + gathered_outputs = [torch.zeros_like(output) for _ in range(world_size)] + torch.distributed.all_gather(gathered_outputs, output) + if local_rank == 0: + # Concatenate along sequence dimension to reconstruct full sequence + full_output = torch.cat(gathered_outputs, dim=1) + print(f"[Rank 0] Gathered and concatenated outputs: {full_output.shape}") + # Verify the full output shape matches expected + assert full_output.shape == (batch_size, seq_len, hidden_size), ( + f"Gathered output shape mismatch: expected {(batch_size, seq_len, hidden_size)}, " + f"got {full_output.shape}" + ) + else: + full_output = None + else: + # For baseline (world_size=1), output is already complete + full_output = output + print(f"[Rank 0] No gather needed (world_size=1), output shape: {full_output.shape}") + + # Save output from rank 0 for comparison + if local_rank == 
0: + output_np = full_output.detach().cpu().float().numpy() + with open(output_file, "wb") as f: + pickle.dump(output_np, f) + + mode_str = "baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})" + print( + f"\n[{mode_str}] ✓ Saved output with shape {full_output.shape}:\n" + f" - batch_size={batch_size}, seq_len={seq_len}\n" + f" - num_heads={num_heads}, head_size={head_size}\n" + f" - dtype={dtype}, causal={causal}, use_sync={use_sync}\n" + ) + + destroy_distributed_env() diff --git a/tests/diffusion/attention/test_flash_attn.py b/tests/diffusion/attention/test_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3862405ede037a5a23085280e38810b7206598 --- /dev/null +++ b/tests/diffusion/attention/test_flash_attn.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Test script for FlashAttention backend with padding handling. + +This script tests two main scenarios: +1. Case 1: Comparing padded vs unpadded inputs for batch_size=1 +2. Case 2: Comparing FlashAttention and SDPA backends for batch_size=2 with padding +""" + +import pytest +import torch + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl +from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl + + +def create_attention_mask(batch_size: int, seq_len: int, valid_len: int, device: torch.device) -> torch.Tensor: + """ + Create attention mask where first valid_len tokens are valid (1) and rest are padding (0). + + Args: + batch_size: Batch size + seq_len: Total sequence length (including padding) + valid_len: Number of valid (non-padded) tokens + + Returns: + Attention mask of shape (batch_size, seq_len) + """ + mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) + mask[:, :valid_len] = True + return mask + + +def pad_tensor(tensor: torch.Tensor, target_seq_len: int, pad_value: float = 0.0) -> torch.Tensor: + """ + Pad tensor along sequence dimension (dim=1). + + Args: + tensor: Input tensor of shape (batch_size, seq_len, num_heads, head_dim) + target_seq_len: Target sequence length after padding + pad_value: Value to use for padding + + Returns: + Padded tensor of shape (batch_size, target_seq_len, num_heads, head_dim) + """ + batch_size, seq_len, num_heads, head_dim = tensor.shape + if target_seq_len <= seq_len: + return tensor + + padding = torch.full( + (batch_size, target_seq_len - seq_len, num_heads, head_dim), pad_value, dtype=tensor.dtype, device=tensor.device + ) + return torch.cat([tensor, padding], dim=1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA") +def test_padding_equivalence(): + """ + Case 1: Test that padded and unpadded inputs produce similar outputs. + + - Input A: batch_size=1, hidden_states (1, 48), encoder_hidden_states (1, 16) + Concatenated length: 64, NO attention_mask + - Input B: Same data but padded: hidden_states (1, 58), encoder_hidden_states (1, 26) + Concatenated length: 84, WITH attention_mask + + Expected: Output A and Output B should be very close. 
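+
+    Layout sketch (derived from the constants used below): with hidden_seq_len=48,
+    encoder_seq_len=16 and pad_length=10, Input B's concatenated sequence is
+    [48 valid | 10 pad | 16 valid | 10 pad] (length 84), so the boolean mask marks
+    positions 0-47 and 58-73 as valid.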
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + + # Configuration + batch_size = 1 + hidden_seq_len = 48 + encoder_seq_len = 16 + pad_length = 10 + num_heads = 8 + head_dim = 64 + + # Initialize FlashAttention + fa_impl = FlashAttentionImpl( + num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False + ) + + # Create base tensors with random values (same for both A and B) + torch.manual_seed(42) + hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype) + encoder_hidden_states_base = torch.randn( + batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + # ========== Input A: Unpadded, no attention mask ========== + query_a = torch.cat([hidden_states_base, encoder_hidden_states_base], dim=1) + key_a = query_a.clone() + value_a = query_a.clone() + + attn_metadata_a = AttentionMetadata(attn_mask=None) + + output_a = fa_impl.forward(query=query_a, key=key_a, value=value_a, attn_metadata=attn_metadata_a) + + # ========== Input B: Padded with attention mask ========== + hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length) + encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length) + + query_b = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1) + key_b = query_b.clone() + value_b = query_b.clone() + + # Create attention mask + attn_mask_b = torch.cat( + [ + create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device), + create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device), + ], + dim=1, + ) + + attn_metadata_b = AttentionMetadata(attn_mask=attn_mask_b) + + output_b = fa_impl.forward(query=query_b, key=key_b, value=value_b, attn_metadata=attn_metadata_b) + + # Extract non-padded portion from output_b + output_b_unpadded = torch.cat( + [ + output_b[:, :hidden_seq_len, :, :], + output_b[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], + ], + dim=1, + ) + + # Compare outputs + max_diff = torch.max(torch.abs(output_a - output_b_unpadded)).item() + mean_diff = torch.mean(torch.abs(output_a - output_b_unpadded)).item() + + print("\n=== Case 1: Padding Equivalence Test ===") + print(f"Output A shape: {output_a.shape}") + print(f"Output B shape: {output_b.shape}") + print(f"Output B unpadded shape: {output_b_unpadded.shape}") + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + + # Assert that outputs are close + # Using higher tolerance for bfloat16 + assert max_diff < 0.1, f"Max difference {max_diff} exceeds threshold 0.1" + assert mean_diff < 0.01, f"Mean difference {mean_diff} exceeds threshold 0.01" + + print("✓ Case 1 PASSED: Padded and unpadded outputs are very close!") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA") +def test_fa_vs_sdpa(): + """ + Case 2: Compare FlashAttention and SDPA backends with padding. + + - batch_size=2 + - hidden_states: (2, 48) padded to (2, 58) + - encoder_hidden_states: (2, 16) padded to (2, 26) + - Concatenated length: 84 + - Compare FA and SDPA outputs + + Expected: FA and SDPA outputs should be very close. 
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + + # Configuration + batch_size = 2 + hidden_seq_len = 48 + encoder_seq_len = 16 + pad_length = 10 + num_heads = 8 + head_dim = 64 + + # Initialize both backends + fa_impl = FlashAttentionImpl( + num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False + ) + + sdpa_impl = SDPAImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False) + + # Create base tensors + torch.manual_seed(123) + hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype) + encoder_hidden_states_base = torch.randn( + batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + # Pad tensors + hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length) + encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length) + + # Concatenate + query = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1) + key = query.clone() + value = query.clone() + + # Create attention mask + attn_mask = torch.cat( + [ + create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device), + create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device), + ], + dim=1, + ) + + attn_metadata = AttentionMetadata(attn_mask=attn_mask) + + # Run FlashAttention + output_fa = fa_impl.forward(query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata) + + # Run SDPA + # SDPA expects 4D attention mask: (batch_size, 1, seq_len, seq_len) or (batch_size, seq_len) + # For causal=False, we need to convert 2D mask to 4D + if attn_mask is not None: + # Expand mask for SDPA: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) + attn_mask_4d = attn_mask.unsqueeze(1).unsqueeze(2) + # Convert bool to float: True -> 0.0, False -> -inf + attn_mask_float = torch.zeros_like(attn_mask_4d, dtype=dtype) + attn_mask_float.masked_fill_(~attn_mask_4d, float("-inf")) + attn_metadata_sdpa = AttentionMetadata(attn_mask=attn_mask_float) + else: + attn_metadata_sdpa = AttentionMetadata(attn_mask=None) + + output_sdpa = sdpa_impl.forward( + query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata_sdpa + ) + + # Compare outputs (only compare valid regions) + output_fa_valid = torch.cat( + [ + output_fa[:, :hidden_seq_len, :, :], + output_fa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], + ], + dim=1, + ) + output_sdpa_valid = torch.cat( + [ + output_sdpa[:, :hidden_seq_len, :, :], + output_sdpa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], + ], + dim=1, + ) + + max_diff = torch.max(torch.abs(output_fa_valid - output_sdpa_valid)).item() + mean_diff = torch.mean(torch.abs(output_fa_valid - output_sdpa_valid)).item() + + print("\n=== Case 2: FA vs SDPA Comparison ===") + print(f"Batch size: {batch_size}") + print(f"FA output shape: {output_fa.shape}") + print(f"SDPA output shape: {output_sdpa.shape}") + print(f"Max absolute difference (valid region): {max_diff:.6f}") + print(f"Mean absolute difference (valid region): {mean_diff:.6f}") + + # Assert that outputs are close + # Using higher tolerance for bfloat16 and different implementations + assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold 0.01" + assert mean_diff < 0.001, f"Mean difference {mean_diff} exceeds threshold 0.001" + + print("✓ Case 2 
PASSED: FA and SDPA outputs are very close!") + + +if __name__ == "__main__": + print("Running FlashAttention Padding Tests...") + print("=" * 60) + + # Try to run CUDA tests + if torch.cuda.is_available(): + try: + print("\n[Running Case 1: Padding Equivalence for FA]") + test_padding_equivalence() + except Exception as e: + print(f"✗ Case 1 failed: {e}") + import traceback + + traceback.print_exc() + + try: + print("\n[Running Case 2: FA vs SDPA]") + test_fa_vs_sdpa() + except Exception as e: + print(f"✗ Case 2 failed: {e}") + import traceback + + traceback.print_exc() + else: + raise RuntimeError("CUDA is not available") + print("\n" + "=" * 60) + print("Test suite completed!") diff --git a/tests/diffusion/cache/__init__.py b/tests/diffusion/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7fc7fcd405c1c1f1d1b099bd65dd4504f63c40 --- /dev/null +++ b/tests/diffusion/cache/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Unit tests for cache backends (cache-dit and teacache). +""" diff --git a/tests/diffusion/cache/test_cache_backends.py b/tests/diffusion/cache/test_cache_backends.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9301410cae1f8858ce8740cc7bfe2f91b9b4ab --- /dev/null +++ b/tests/diffusion/cache/test_cache_backends.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Unit tests for cache backends (cache-dit and teacache). + +This module tests the cache backend implementations: +- CacheDiTBackend: cache-dit acceleration backend +- TeaCacheBackend: TeaCache hook-based backend +- Cache selector function: get_cache_backend +- DiffusionCacheConfig: configuration dataclass +""" + +from unittest.mock import Mock, patch + +import pytest + +from vllm_omni.diffusion.cache.cache_dit_backend import ( + CacheDiTBackend, +) +from vllm_omni.diffusion.cache.selector import get_cache_backend +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend +from vllm_omni.diffusion.data import DiffusionCacheConfig + + +class TestCacheDiTBackend: + """Test CacheDiTBackend implementation.""" + + def test_init_with_dict(self): + """Test initialization with dictionary config.""" + config_dict = {"Fn_compute_blocks": 4, "max_warmup_steps": 8} + backend = CacheDiTBackend(config_dict) + assert backend.config.Fn_compute_blocks == 4 + assert backend.config.max_warmup_steps == 8 + assert backend.enabled is False + + def test_init_with_config_object(self): + """Test initialization with DiffusionCacheConfig object.""" + config = DiffusionCacheConfig(Fn_compute_blocks=4) + backend = CacheDiTBackend(config) + assert backend.config.Fn_compute_blocks == 4 + assert backend.enabled is False + + @patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit") + def test_enable_single_transformer(self, mock_cache_dit): + """Test enabling cache-dit on single-transformer pipeline.""" + # Mock pipeline + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "DiTPipeline" + mock_transformer = Mock() + mock_pipeline.transformer = mock_transformer + + # Mock cache_dit functions + mock_cache_dit.enable_cache = Mock() + mock_cache_dit.refresh_context = Mock() + + backend = CacheDiTBackend({"Fn_compute_blocks": 2}) + backend.enable(mock_pipeline) + + # Verify cache-dit was enabled + assert backend.enabled is True + assert backend._refresh_func is not None + 
mock_cache_dit.enable_cache.assert_called_once() + + @patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit") + def test_refresh(self, mock_cache_dit): + """Test refreshing cache context with SCM mask policy updates when num_inference_steps changes.""" + # Mock pipeline + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "DiTPipeline" + mock_transformer = Mock() + mock_pipeline.transformer = mock_transformer + + # Mock cache_dit functions + mock_cache_dit.enable_cache = Mock() + mock_cache_dit.refresh_context = Mock() + mock_steps_mask_50 = [1, 0, 1, 0, 1] * 10 # Mock mask for 50 steps + mock_steps_mask_100 = [1, 0, 1, 0, 1] * 20 # Mock mask for 100 steps + mock_cache_dit.steps_mask = Mock(side_effect=[mock_steps_mask_50, mock_steps_mask_100]) + + # Enable cache-dit with SCM enabled (using mask policy) + config = DiffusionCacheConfig( + scm_steps_mask_policy="fast", + scm_steps_policy="dynamic", + ) + backend = CacheDiTBackend(config) + backend.enable(mock_pipeline) + + # First refresh with 50 steps + backend.refresh(mock_pipeline, num_inference_steps=50) + assert backend._last_num_inference_steps == 50 + + # Verify steps_mask was called with mask policy (not direct steps mask) + mock_cache_dit.steps_mask.assert_called_with(mask_policy="fast", total_steps=50) + assert mock_cache_dit.steps_mask.call_count == 1 + + # Verify refresh_context was called with cache_config (SCM path) + mock_cache_dit.refresh_context.assert_called_once() + call_args = mock_cache_dit.refresh_context.call_args + assert call_args[0][0] == mock_transformer + # Check that cache_config was passed (not num_inference_steps directly when SCM is enabled) + assert "cache_config" in call_args[1] + cache_config_arg = call_args[1]["cache_config"] + assert cache_config_arg is not None + + # Change num_inference_steps and refresh again + mock_cache_dit.refresh_context.reset_mock() + backend.refresh(mock_pipeline, num_inference_steps=100) + + # Verify steps_mask was called again with new num_inference_steps (using mask policy) + assert mock_cache_dit.steps_mask.call_count == 2 + # Check the last call was with 100 steps and mask policy + assert mock_cache_dit.steps_mask.call_args_list[-1].kwargs["total_steps"] == 100 + assert mock_cache_dit.steps_mask.call_args_list[-1].kwargs["mask_policy"] == "fast" + + # Verify refresh_context was called again with updated mask + mock_cache_dit.refresh_context.assert_called_once() + call_args = mock_cache_dit.refresh_context.call_args + assert call_args[0][0] == mock_transformer + assert "cache_config" in call_args[1] + assert backend._last_num_inference_steps == 100 + + +class TestTeaCacheBackend: + """Test TeaCacheBackend implementation.""" + + def test_init(self): + """Test initialization.""" + config = DiffusionCacheConfig(rel_l1_thresh=0.3) + backend = TeaCacheBackend(config) + assert backend.config.rel_l1_thresh == 0.3 + assert backend.enabled is False + + @patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook") + def test_enable(self, mock_apply_hook): + """Test enabling TeaCache on pipeline.""" + # Mock pipeline + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "QwenImagePipeline" + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + config = DiffusionCacheConfig(rel_l1_thresh=0.3) + backend = TeaCacheBackend(config) + backend.enable(mock_pipeline) + + # Verify hook was applied + assert backend.enabled is True + mock_apply_hook.assert_called_once() + 
+ @patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook") + def test_enable_with_coefficients(self, mock_apply_hook): + """Test enabling TeaCache with custom coefficients.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "QwenImagePipeline" + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + config = DiffusionCacheConfig(rel_l1_thresh=0.3, coefficients=[1.0, 0.5, 0.2, 0.1, 0.05]) + backend = TeaCacheBackend(config) + backend.enable(mock_pipeline) + + assert backend.enabled is True + mock_apply_hook.assert_called_once() + + @patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook") + def test_refresh(self, mock_apply_hook): + """Test refreshing TeaCache state.""" + mock_pipeline = Mock() + mock_pipeline.__class__.__name__ = "QwenImagePipeline" + mock_transformer = Mock() + mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel" + mock_pipeline.transformer = mock_transformer + + # Mock hook registry + mock_hook = Mock() + mock_registry = Mock() + mock_registry.get_hook = Mock(return_value=mock_hook) + mock_registry.reset_hook = Mock() + mock_transformer._hook_registry = mock_registry + + config = DiffusionCacheConfig() + backend = TeaCacheBackend(config) + backend.enable(mock_pipeline) + + # Test refresh + backend.refresh(mock_pipeline, num_inference_steps=50) + mock_registry.reset_hook.assert_called_once() + + +class TestCacheSelector: + """Test cache backend selector function.""" + + def test_get_cache_backend_none(self): + """Test getting None backend.""" + backend = get_cache_backend(None, None) + assert backend is None + + backend = get_cache_backend("none", None) + assert backend is None + + def test_get_cache_backend_cache_dit(self): + """Test getting cache-dit backend.""" + config_dict = {"Fn_compute_blocks": 4} + backend = get_cache_backend("cache_dit", config_dict) + assert isinstance(backend, CacheDiTBackend) + assert backend.config.Fn_compute_blocks == 4 + + def test_get_cache_backend_tea_cache(self): + """Test getting teacache backend.""" + config_dict = {"rel_l1_thresh": 0.3} + backend = get_cache_backend("tea_cache", config_dict) + assert isinstance(backend, TeaCacheBackend) + assert backend.config.rel_l1_thresh == 0.3 + + def test_get_cache_backend_invalid(self): + """Test getting invalid backend raises error.""" + with pytest.raises(ValueError, match="Unsupported cache backend"): + get_cache_backend("invalid_backend", {}) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..24e4559de366cbd01fdeeecbe2481fd20f5251b3 --- /dev/null +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality. + +This test verifies that predict_noise_maybe_with_cfg produces numerically +equivalent results with and without CFG parallel using fixed random inputs. 
+""" + +import os + +import pytest +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.platforms import current_omni_platform + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +class SimpleTransformer(torch.nn.Module): + """Simple transformer model for testing with random initialization. + + Contains: + - Input projection (conv to hidden_dim) + - QKV projection layers + - Self-attention layer + - Output projection + """ + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, num_heads: int = 8): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" + + # Input projection: (B, C, H, W) -> (B, hidden_dim, H, W) + self.input_proj = torch.nn.Conv2d(in_channels, hidden_dim, 1) + + # QKV projection layers + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Output projection after attention + self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Final output projection: (B, hidden_dim, H, W) -> (B, C, H, W) + self.final_proj = torch.nn.Conv2d(hidden_dim, in_channels, 1) + + # Layer norm + self.norm1 = torch.nn.LayerNorm(hidden_dim) + self.norm2 = torch.nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: + """Forward pass with self-attention. 
+ + Args: + x: Input tensor of shape (B, C, H, W) + + Returns: + Output tensor of shape (B, C, H, W) + """ + B, C, H, W = x.shape + + # Input projection + x = self.input_proj(x) # (B, hidden_dim, H, W) + + # Reshape to sequence: (B, hidden_dim, H, W) -> (B, H*W, hidden_dim) + x = x.flatten(2).transpose(1, 2) # (B, H*W, hidden_dim) + + # Self-attention with residual connection + residual = x + x = self.norm1(x) + + # QKV projection + q = self.q_proj(x) # (B, H*W, hidden_dim) + k = self.k_proj(x) # (B, H*W, hidden_dim) + v = self.v_proj(x) # (B, H*W, hidden_dim) + + # Reshape for multi-head attention: (B, H*W, hidden_dim) -> (B, num_heads, H*W, head_dim) + seq_len = H * W + q = q.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Scaled dot-product attention + scale = self.head_dim**-0.5 + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (B, num_heads, H*W, H*W) + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous().view(B, seq_len, self.hidden_dim) + + attn_output = self.out_proj(attn_output) + + x = residual + attn_output + residual = x + x = self.norm2(x) + x = residual + x + x = x.transpose(1, 2).view(B, self.hidden_dim, H, W) + + out = self.final_proj(x) + + return (out,) + + +class TestCFGPipeline(CFGParallelMixin): + """Test pipeline using CFGParallelMixin.""" + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42): + # Set seed BEFORE creating transformer to ensure consistent layer initialization + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + self.transformer = SimpleTransformer(in_channels, hidden_dim) + + # Re-initialize all parameters with fixed seed for full reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + for param in self.transformer.parameters(): + torch.nn.init.normal_(param, mean=0.0, std=0.02) + + +def _test_cfg_parallel_worker( + local_rank: int, + world_size: int, + cfg_parallel_size: int, + dtype: torch.dtype, + test_config: dict, + result_queue: torch.multiprocessing.Queue, +): + """Worker function for CFG parallel test.""" + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29502", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=cfg_parallel_size) + + cfg_rank = get_classifier_free_guidance_rank() + cfg_world_size = get_classifier_free_guidance_world_size() + + assert cfg_world_size == cfg_parallel_size + + # Create pipeline with same seed to ensure identical model weights across all ranks + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() # Set to eval mode for deterministic behavior + + # Create fixed inputs with explicit seed setting for reproducibility + # Set both CPU and CUDA seeds to ensure identical 
inputs across all ranks + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Negative input with different seed + torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Prepare kwargs for predict_noise_maybe_with_cfg + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + # Call predict_noise_maybe_with_cfg + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Only rank 0 has valid output in CFG parallel mode + if cfg_rank == 0: + assert noise_pred is not None + result_queue.put(noise_pred.cpu()) + else: + assert noise_pred is None + + destroy_distributed_env() + + +def _test_cfg_sequential_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + test_config: dict, + result_queue: torch.multiprocessing.Queue, +): + """Worker function for sequential CFG test (baseline).""" + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29503", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=1) # No CFG parallel + + cfg_world_size = get_classifier_free_guidance_world_size() + assert cfg_world_size == 1 + + # Create pipeline with same seed to ensure identical model weights as CFG parallel + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() + + # Create fixed inputs (same seed as CFG parallel to ensure identical inputs) + # Set both CPU and CUDA seeds for full reproducibility + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Negative input with different seed + torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + 
positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Sequential CFG always returns output + assert noise_pred is not None + result_queue.put(noise_pred.cpu()) + + destroy_distributed_env() + + +@pytest.mark.parametrize("cfg_parallel_size", [2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("cfg_normalize", [False, True]) +def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool): + """ + Test that predict_noise_maybe_with_cfg produces identical results + with and without CFG parallel. + + Args: + cfg_parallel_size: Number of GPUs for CFG parallel + dtype: Data type for computation + batch_size: Batch size for testing + cfg_normalize: Whether to normalize CFG output + """ + available_gpus = current_omni_platform.get_device_count() + if available_gpus < cfg_parallel_size: + pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available") + + test_config = { + "batch_size": batch_size, + "channels": 4, + "height": 16, + "width": 16, + "hidden_dim": 128, + "cfg_scale": 7.5, + "cfg_normalize": cfg_normalize, + "model_seed": 42, # Fixed seed for model initialization + "input_seed": 123, # Fixed seed for input generation + } + + mp_context = torch.multiprocessing.get_context("spawn") + + manager = mp_context.Manager() + baseline_queue = manager.Queue() + cfg_parallel_queue = manager.Queue() + + # Run baseline (sequential CFG) on single GPU + torch.multiprocessing.spawn( + _test_cfg_sequential_worker, + args=(1, dtype, test_config, baseline_queue), + nprocs=1, + ) + + # Run CFG parallel on multiple GPUs + torch.multiprocessing.spawn( + _test_cfg_parallel_worker, + args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue), + nprocs=cfg_parallel_size, + ) + + # Get results from queues + baseline_output = baseline_queue.get() + cfg_parallel_output = cfg_parallel_queue.get() + + # Verify shapes match + assert baseline_output.shape == cfg_parallel_output.shape, ( + f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}" + ) + + # Verify numerical equivalence with appropriate tolerances + if dtype == torch.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == torch.bfloat16: + rtol, atol = 1e-2, 1e-2 + else: + rtol, atol = 1e-3, 1e-3 + + torch.testing.assert_close( + cfg_parallel_output, + baseline_output, + rtol=rtol, + atol=atol, + msg=( + f"CFG parallel output differs from sequential CFG\n" + f" dtype={dtype}, batch_size={batch_size}, cfg_normalize={cfg_normalize}\n" + f" Max diff: {(cfg_parallel_output - baseline_output).abs().max().item():.6e}" + ), + ) + + print( + f"✓ Test passed: cfg_size={cfg_parallel_size}, dtype={dtype}, " + f"batch_size={batch_size}, cfg_normalize={cfg_normalize}" + ) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_predict_noise_without_cfg(dtype: torch.dtype): + """ + Test predict_noise_maybe_with_cfg when do_true_cfg=False. + + When CFG is disabled, only the positive branch should be computed. + This test runs on a single GPU without distributed environment. 
+ """ + available_gpus = current_omni_platform.get_device_count() + if available_gpus < 1: + pytest.skip("Test requires at least 1 GPU") + + device = torch.device(f"{current_omni_platform.device_type}:0") + current_omni_platform.set_device(device) + + # Create pipeline without distributed environment + pipeline = TestCFGPipeline(in_channels=4, hidden_dim=128, seed=42) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() + + # Set seed for input generation + torch.manual_seed(123) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(123) + positive_input = torch.randn(1, 4, 16, 16, dtype=dtype, device=device) + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=False, # No CFG + true_cfg_scale=7.5, + positive_kwargs={"x": positive_input}, + negative_kwargs=None, + cfg_normalize=False, + ) + + # Should always return output when do_true_cfg=False + assert noise_pred is not None + assert noise_pred.shape == (1, 4, 16, 16) + + print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})") diff --git a/tests/diffusion/distributed/test_comm.py b/tests/diffusion/distributed/test_comm.py new file mode 100644 index 0000000000000000000000000000000000000000..3e408c10a6bee786e51228c41795591608097106 --- /dev/null +++ b/tests/diffusion/distributed/test_comm.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for SeqAllToAll4D and SeqAllToAll5D communication primitives.""" + +import os + +import pytest +import torch + +from vllm_omni.diffusion.distributed.comm import RingComm, SeqAllToAll4D, SeqAllToAll5D +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_sp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.platforms import current_omni_platform + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len_per_rank", [8]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [32]) +@pytest.mark.parametrize("use_sync", [False, True]) +def test_4d_identity( + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + """Test that two consecutive all-to-all operations return the original input.""" + # Skip if not enough GPUs available + available_gpus = current_omni_platform.get_device_count() + if available_gpus < world_size: + pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available") + + # Ensure num_heads is divisible by world_size + if num_heads % world_size != 0: + pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})") + + # Run test with multiprocessing spawn + torch.multiprocessing.spawn( + _test_4d_identity_worker, + args=( + world_size, + dtype, + batch_size, + seq_len_per_rank, + num_heads, + head_size, + use_sync, + ), + nprocs=world_size, + ) + + +def _test_4d_identity_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + 
"""Worker function for test_4d_identity.""" + # Set device + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + # Set environment variables for distributed training + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29500", + } + ) + + # Initialize distributed environment + init_distributed_environment() + initialize_model_parallel(ulysses_degree=world_size) # test ulysses sp by default + sp_group = get_sp_group().ulysses_group # get ulysses sp group not ring sp group + + # Create input tensor: (bs, seqlen/P, hc, hs) + torch.manual_seed(42 + local_rank) + input_tensor = torch.randn( + batch_size, + seq_len_per_rank, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + + # Save original input for comparison + original_input = input_tensor.clone() + + # First all-to-all: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs) + intermediate = SeqAllToAll4D.apply( + sp_group, + input_tensor, + 2, # scatter head dimension + 1, # gather sequence dimension + use_sync, + ) + + # Verify intermediate shape + expected_shape = ( + batch_size, + seq_len_per_rank * world_size, + num_heads // world_size, + head_size, + ) + assert intermediate.shape == expected_shape, ( + f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}" + ) + + # Second all-to-all: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs) + output = SeqAllToAll4D.apply( + sp_group, + intermediate, + 1, # scatter sequence dimension + 2, # gather head dimension + use_sync, + ) + + # Verify output shape matches input + assert output.shape == original_input.shape, ( + f"Output shape mismatch: expected {original_input.shape}, got {output.shape}" + ) + + # Verify output matches original input + torch.testing.assert_close( + output, + original_input, + rtol=1e-5, + atol=1e-5, + msg="Output does not match original input after two all-to-all operations", + ) + + # Cleanup distributed environment + destroy_distributed_env() + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seq_len_per_rank", [8]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [32]) +@pytest.mark.parametrize("use_sync", [False, True]) +def test_5d_identity( + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + """Test that two consecutive all-to-all operations return the original input.""" + # Skip if not enough GPUs available + available_gpus = current_omni_platform.get_device_count() + if available_gpus < world_size: + pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available") + + # Ensure num_heads is divisible by world_size + if num_heads % world_size != 0: + pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})") + + # Run test with multiprocessing spawn + torch.multiprocessing.spawn( + _test_5d_identity_worker, + args=( + world_size, + dtype, + batch_size, + seq_len_per_rank, + num_heads, + head_size, + use_sync, + ), + nprocs=world_size, + ) + + +def _test_5d_identity_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + seq_len_per_rank: int, + num_heads: int, + head_size: int, + use_sync: bool, +): + 
"""Worker function for test_5d_identity.""" + # Set device + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + # Set environment variables for distributed training + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29500", + } + ) + + # Initialize distributed environment + init_distributed_environment() + initialize_model_parallel(ulysses_degree=world_size) # test ulysses sp by default + sp_group = get_sp_group().ulysses_group # get ulysses sp group not ring sp group + + # Create input tensor: (bs, seqlen/P, 3, hc, hs) + # The '3' dimension is for Q, K, V + torch.manual_seed(42 + local_rank) + input_tensor = torch.randn( + batch_size, + seq_len_per_rank, + 3, # Q, K, V + num_heads, + head_size, + dtype=dtype, + device=device, + ) + + # Save original input for comparison + original_input = input_tensor.clone() + + # First all-to-all: (bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs) + intermediate = SeqAllToAll5D.apply( + sp_group, + input_tensor, + 3, # scatter head dimension + 1, # gather sequence dimension + use_sync, + ) + + # Verify intermediate shape + expected_shape = ( + batch_size, + seq_len_per_rank * world_size, + 3, + num_heads // world_size, + head_size, + ) + assert intermediate.shape == expected_shape, ( + f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}" + ) + + # Second all-to-all: (bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs) + output = SeqAllToAll5D.apply( + sp_group, + intermediate, + 1, # scatter sequence dimension + 3, # gather head dimension + use_sync, + ) + + # Verify output shape matches input + assert output.shape == original_input.shape, ( + f"Output shape mismatch: expected {original_input.shape}, got {output.shape}" + ) + + # Verify output matches original input + torch.testing.assert_close( + output, + original_input, + rtol=1e-5, + atol=1e-5, + msg="Output does not match original input after two all-to-all operations", + ) + + # Cleanup distributed environment + destroy_distributed_env() + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [128]) +def test_ring_p2p( + world_size: int, + dtype: torch.dtype, + batch_size: int, + num_heads: int, + head_size: int, +): + """Test Ring P2P communication (send_recv).""" + # Skip if not enough GPUs available + available_gpus = current_omni_platform.get_device_count() + if available_gpus < world_size: + pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available") + + torch.multiprocessing.spawn( + _test_ring_p2p_worker, + args=(world_size, dtype, batch_size, num_heads, head_size), + nprocs=world_size, + ) + + +def _test_ring_p2p_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + batch_size: int, + num_heads: int, + head_size: int, +): + """Worker for Ring P2P test.""" + import sys + + # Set device + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + # Set env vars + # Use a different port to avoid conflict with other tests if run in parallel + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + 
"MASTER_ADDR": "localhost", + "MASTER_PORT": "29501", + } + ) + + # Init distributed + try: + init_distributed_environment() + # Ring degree = world_size to test ring group + initialize_model_parallel(ring_degree=world_size) + sp_group = get_sp_group() + + print(f"[Rank {local_rank}] Initialized. Ring group size: {sp_group.ring_group.size()}") + sys.stdout.flush() + + # Create RingComm + comm = RingComm(sp_group.ring_group) + + # Create tensor: rank-specific data + # (batch, num_heads, head_size) + # Fill with rank value + 1 to avoid 0 and make verification easy + input_tensor = torch.full( + (batch_size, num_heads, head_size), fill_value=float(local_rank + 1), dtype=dtype, device=device + ) + + print(f"[Rank {local_rank}] Input sum: {input_tensor.sum().item()}") + sys.stdout.flush() + + # Send input, receive from prev + # RingComm.send_recv sends to next, receives from prev + t0 = __import__("time").time() + recv_tensor = comm.send_recv(input_tensor) + comm.commit() + comm.wait() + t1 = __import__("time").time() + + print(f"[Rank {local_rank}] Communication done in {t1 - t0:.4f}s") + + # Verify + # Expected value: from (rank - 1) % world_size + prev_rank = (local_rank - 1 + world_size) % world_size + expected_value = float(prev_rank + 1) + + recv_sum = recv_tensor.sum().item() + print(f"[Rank {local_rank}] Received sum: {recv_sum}, Expected value: {expected_value}") + sys.stdout.flush() + + expected_tensor = torch.full_like(recv_tensor, fill_value=expected_value) + + # Use a slightly loose tolerance for bfloat16 + torch.testing.assert_close( + recv_tensor, expected_tensor, rtol=1e-3, atol=1e-3, msg=f"[Rank {local_rank}] Data mismatch!" + ) + print(f"[Rank {local_rank}] Verification PASSED") + + except Exception as e: + print(f"[Rank {local_rank}] FAILED with error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + destroy_distributed_env() diff --git a/tests/diffusion/distributed/test_sp_plan_hooks.py b/tests/diffusion/distributed/test_sp_plan_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..6883ae09f0caa04e09997b49d6df824c21236d56 --- /dev/null +++ b/tests/diffusion/distributed/test_sp_plan_hooks.py @@ -0,0 +1,1022 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the Sequence Parallelism (SP) framework. + +These tests verify the SP plan mechanism and hooks work correctly without +requiring a distributed environment. They test: +1. _sp_plan validation (sp_plan.py) +2. Hook utilities and submodule resolution (sequence_parallel.py) +3. Model _sp_plan definitions +4. Tensor sharding simulation + +Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) +in diffusers. We use "Sequence Parallelism" to align with vLLM-Omni terminology. 
+""" + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, + SequenceParallelPartialInput, + get_sp_plan_from_model, + validate_sp_plan, +) + + +def is_distributed_initialized() -> bool: + """Check if distributed environment is initialized.""" + try: + from vllm_omni.diffusion.distributed.parallel_state import get_sp_group + + get_sp_group() + return True + except (AssertionError, ImportError): + return False + + +# Decorator to skip tests that require distributed environment +requires_distributed = pytest.mark.skipif( + not is_distributed_initialized(), + reason="Requires initialized distributed environment (SP group)", +) + +# Module-level markers: these tests are diffusion + parallel related +pytestmark = [ + pytest.mark.diffusion, + pytest.mark.parallel, +] + +# ============================================================================= +# Tests for sp_plan.py +# ============================================================================= + + +@pytest.mark.cpu +class TestSequenceParallelPlanValidation: + """Test _sp_plan validation logic.""" + + def test_valid_simple_plan(self): + """Test a simple valid _sp_plan.""" + plan = { + "rope": { + 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + # Should not raise + validate_sp_plan(plan) + + def test_valid_partial_input_plan(self): + """Test a valid _sp_plan with SequenceParallelPartialInput.""" + plan = { + "pos_embed": { + 0: SequenceParallelPartialInput( + split_dim=0, + text_len_source="txt_ids", + expected_dims=2, + split_output=True, + ), + }, + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + } + # Should not raise + validate_sp_plan(plan) + + def test_invalid_plan_type(self): + """Test that non-dict plan raises error.""" + with pytest.raises(ValueError, match="must be a dict"): + validate_sp_plan("not a dict") + + def test_invalid_module_key_type(self): + """Test that non-string module keys raise error.""" + plan = {123: {"hidden_states": SequenceParallelInput(split_dim=1)}} + with pytest.raises(ValueError, match="keys must be strings"): + validate_sp_plan(plan) + + def test_invalid_output_index_without_split_output(self): + """Test that integer keys require split_output=True.""" + plan = { + "rope": { + 0: SequenceParallelInput(split_dim=1, split_output=False), # Invalid + } + } + with pytest.raises(ValueError, match="split_output=True"): + validate_sp_plan(plan) + + +@pytest.mark.cpu +class TestGetSpPlanFromModel: + """Test get_sp_plan_from_model utility.""" + + def test_model_with_sp_plan(self): + """Test getting _sp_plan from a model that has one.""" + + class ModelWithPlan(nn.Module): + _sp_plan = { + "layer": { + "x": SequenceParallelInput(split_dim=1), + } + } + + model = ModelWithPlan() + plan = get_sp_plan_from_model(model) + assert plan is not None + assert "layer" in plan + + def test_model_without_sp_plan(self): + """Test getting _sp_plan from a model without one.""" + + class ModelWithoutPlan(nn.Module): + pass + + model = ModelWithoutPlan() + plan = get_sp_plan_from_model(model) + assert plan is None + + +@pytest.mark.cpu +class TestSequenceParallelInputTypes: + """Test SequenceParallelInput and 
related types.""" + + def test_sequence_parallel_input_repr(self): + """Test SequenceParallelInput repr.""" + spi = SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True) + assert "split_dim=1" in repr(spi) + assert "expected_dims=3" in repr(spi) + assert "split_output=True" in repr(spi) + + def test_sequence_parallel_output_repr(self): + """Test SequenceParallelOutput repr.""" + spo = SequenceParallelOutput(gather_dim=1, expected_dims=3) + assert "gather_dim=1" in repr(spo) + assert "expected_dims=3" in repr(spo) + + def test_sequence_parallel_partial_input_repr(self): + """Test SequenceParallelPartialInput repr.""" + sppi = SequenceParallelPartialInput( + split_dim=0, + text_len_source="txt_ids", + expected_dims=2, + split_output=True, + ) + assert "split_dim=0" in repr(sppi) + assert "txt_ids" in repr(sppi) + assert "expected_dims=2" in repr(sppi) + assert "split_output=True" in repr(sppi) + + def test_sequence_parallel_partial_input_with_int_source(self): + """Test SequenceParallelPartialInput with integer text_len_source.""" + sppi = SequenceParallelPartialInput( + split_dim=0, + text_len_source=512, # Fixed length + expected_dims=2, + ) + assert sppi.text_len_source == 512 + + +# ============================================================================= +# Tests for sequence_parallel.py +# ============================================================================= + + +@pytest.mark.cpu +class TestModuleForwardMetadata: + """Test ModuleForwardMetadata parameter resolution.""" + + def test_get_parameter_from_kwargs(self): + """Test getting parameter from kwargs.""" + from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata + + class DummyModule(nn.Module): + def forward(self, hidden_states, encoder_hidden_states): + pass + + metadata = ModuleForwardMetadata() + metadata._cls = DummyModule + + kwargs = {"hidden_states": torch.randn(2, 4, 8)} + val, is_kwarg, index = metadata._get_parameter_from_args_kwargs("hidden_states", (), kwargs) + assert is_kwarg is True + assert index is None + assert val.shape == (2, 4, 8) + + def test_get_parameter_from_args(self): + """Test getting parameter from positional args.""" + from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata + + class DummyModule(nn.Module): + def forward(self, hidden_states, encoder_hidden_states): + pass + + metadata = ModuleForwardMetadata() + metadata._cls = DummyModule + + tensor = torch.randn(2, 4, 8) + args = (tensor,) + val, is_kwarg, index = metadata._get_parameter_from_args_kwargs("hidden_states", args, {}) + assert is_kwarg is False + assert index == 0 + assert torch.equal(val, tensor) + + def test_parameter_caching(self): + """Test that parameter indices are cached.""" + from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata + + class DummyModule(nn.Module): + def forward(self, a, b, c): + pass + + metadata = ModuleForwardMetadata() + metadata._cls = DummyModule + + # First call - should populate cache + args = (torch.randn(1), torch.randn(1), torch.randn(1)) + metadata._get_parameter_from_args_kwargs("b", args, {}) + + # Check cache was populated + assert metadata.cached_parameter_indices is not None + assert metadata.cached_parameter_indices["a"] == 0 + assert metadata.cached_parameter_indices["b"] == 1 + assert metadata.cached_parameter_indices["c"] == 2 + + +@pytest.mark.cpu +class TestGetSubmoduleByName: + """Test _get_submodule_by_name function.""" + + def test_root_module(self): + """Test getting root module with empty 
string.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + model = nn.Linear(10, 10) + submodule = _get_submodule_by_name(model, "") + assert submodule is model + + def test_simple_submodule(self): + """Test getting a simple submodule.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + + model = Model() + submodule = _get_submodule_by_name(model, "layer") + assert submodule is model.layer + + def test_nested_submodule(self): + """Test getting a nested submodule.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(10, 10), nn.ReLU()) + + model = Model() + submodule = _get_submodule_by_name(model, "encoder.0") + assert isinstance(submodule, nn.Linear) + + def test_module_list_by_index(self): + """Test getting element from ModuleList by index.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)]) + + model = Model() + submodule = _get_submodule_by_name(model, "blocks.0") + assert submodule is model.blocks[0] + + submodule = _get_submodule_by_name(model, "blocks.2") + assert submodule is model.blocks[2] + + def test_wildcard_modulelist(self): + """Test wildcard matching for ModuleList.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)]) + + model = Model() + submodules = _get_submodule_by_name(model, "blocks.*") + assert isinstance(submodules, list) + assert len(submodules) == 3 + for i, sm in enumerate(submodules): + assert sm is model.blocks[i] + + def test_module_dict(self): + """Test getting submodule from ModuleDict.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.outputs = nn.ModuleDict({"main": nn.Linear(10, 10), "aux": nn.Linear(10, 5)}) + + model = Model() + submodule = _get_submodule_by_name(model, "outputs.main") + assert submodule is model.outputs["main"] + + submodule = _get_submodule_by_name(model, "outputs.aux") + assert submodule is model.outputs["aux"] + + def test_invalid_submodule_raises(self): + """Test that invalid submodule path raises error.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 10) + + model = Model() + with pytest.raises(ValueError, match="not a submodule"): + _get_submodule_by_name(model, "nonexistent") + + def test_multiple_wildcards_raises(self): + """Test that multiple wildcards raise error.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + model = nn.Linear(10, 10) + with pytest.raises(ValueError, match="only be used once"): + _get_submodule_by_name(model, "a.*.b.*") + + +@pytest.mark.cpu +class TestHookRegistration: + """Test hook registration logic (without distributed backend).""" + + def test_plan_validation_before_apply(self): + """Test that invalid plans are rejected before hook registration.""" + + class 
SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.proj_in = nn.Linear(10, 10) + self.proj_out = nn.Linear(10, 10) + + def forward(self, x): + return self.proj_out(self.proj_in(x)) + + # Invalid plan (non-string key) + invalid_plan = { + 123: {"x": SequenceParallelInput(split_dim=1)}, + } + + with pytest.raises(ValueError): + validate_sp_plan(invalid_plan) + + def test_valid_plan_structure_for_model(self): + """Test that a valid plan can be defined for a model.""" + + class SimpleModel(nn.Module): + _sp_plan = { + "proj_in": {"x": SequenceParallelInput(split_dim=1, expected_dims=3)}, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def __init__(self): + super().__init__() + self.proj_in = nn.Linear(10, 10) + self.proj_out = nn.Linear(10, 10) + + def forward(self, x): + return self.proj_out(self.proj_in(x)) + + model = SimpleModel() + plan = get_sp_plan_from_model(model) + + assert plan is not None + assert "proj_in" in plan + assert "proj_out" in plan + + # Verify submodules exist + from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name + + assert _get_submodule_by_name(model, "proj_in") is model.proj_in + assert _get_submodule_by_name(model, "proj_out") is model.proj_out + + +# ============================================================================= +# Tests for model _sp_plan definitions +# ============================================================================= + + +@pytest.mark.L4 +class TestModelSpPlans: + """Test that model _sp_plan definitions are valid. + + These tests import actual model classes to verify _sp_plan structure. + May require GPU for model imports. + """ + + def test_zimage_transformer_sp_plan(self): + """Test ZImageTransformer2DModel _sp_plan structure. + + The plan specifies: + - unified_prepare: Shard all 4 outputs (unified, cos, sin, attn_mask) + - all_final_layer.2-1: Gather outputs after final layer + + Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism) + """ + try: + from vllm_omni.diffusion.models.z_image.z_image_transformer import ZImageTransformer2DModel + + plan = getattr(ZImageTransformer2DModel, "_sp_plan", None) + assert plan is not None, "ZImageTransformer2DModel should define _sp_plan" + assert isinstance(plan, dict) + + assert "unified_prepare" in plan + unified_prepare_plan = plan["unified_prepare"] + # Check all 4 outputs are sharded with split_output=True + assert 0 in unified_prepare_plan # unified + assert 1 in unified_prepare_plan # unified_cos + assert 2 in unified_prepare_plan # unified_sin + assert 3 in unified_prepare_plan # unified_attn_mask + + # Check output gathering + assert "all_final_layer.2-1" in plan + + validate_sp_plan(plan) + except ImportError: + pytest.skip("ZImageTransformer2DModel not available") + + def test_qwen_image_transformer_sp_plan(self): + """Test QwenImageTransformer2DModel _sp_plan structure. + + Qwen-Image follows the diffusers pattern similar to Z-Image: + - image_rope_prepare: Shards hidden_states and vid_freqs together + - proj_out: Gathers output + + Key insight: hidden_states and vid_freqs MUST be sharded together + to maintain dimension alignment for RoPE computation. 
+ + Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism) + """ + try: + from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, + ) + + plan = getattr(QwenImageTransformer2DModel, "_sp_plan", None) + assert plan is not None, "QwenImageTransformer2DModel should define _sp_plan" + assert isinstance(plan, dict) + + # Check image_rope_prepare sharding + assert "image_rope_prepare" in plan + rope_plan = plan["image_rope_prepare"] + # hidden_states (index 0) + assert 0 in rope_plan + assert rope_plan[0].split_dim == 1 + assert rope_plan[0].split_output is True + # vid_freqs (index 1) + assert 1 in rope_plan + assert rope_plan[1].split_dim == 0 + assert rope_plan[1].split_output is True + # txt_freqs (index 2) should NOT be in plan (kept replicated) + assert 2 not in rope_plan + + # Check output gathering at proj_out + assert "proj_out" in plan + proj_out_plan = plan["proj_out"] + assert proj_out_plan.gather_dim == 1 + + validate_sp_plan(plan) + except ImportError: + pytest.skip("QwenImageTransformer2DModel not available") + + +# ============================================================================= +# Tests for tensor sharding simulation (no distributed required) +# ============================================================================= + + +@pytest.mark.cpu +class TestMockSharding: + """Test tensor sharding logic (mocked, no distributed).""" + + def test_shard_tensor_simulation(self): + """Simulate tensor sharding without distributed backend.""" + # Create a test tensor + batch_size, seq_len, hidden_dim = 2, 16, 64 + tensor = torch.randn(batch_size, seq_len, hidden_dim) + + # Simulate sharding for world_size=4 + world_size = 4 + rank = 1 + + # Manual chunking (what sp_shard does internally) + chunks = tensor.chunk(world_size, dim=1) + sharded = chunks[rank] + + assert sharded.shape == (batch_size, seq_len // world_size, hidden_dim) + assert sharded.shape == (2, 4, 64) + + def test_partial_shard_simulation(self): + """Simulate partial sharding (text kept, image sharded).""" + # Create a test tensor with [text, image] concatenated + batch_size = 2 + text_len = 8 + image_len = 16 + hidden_dim = 64 + + text_part = torch.randn(batch_size, text_len, hidden_dim) + image_part = torch.randn(batch_size, image_len, hidden_dim) + tensor = torch.cat([text_part, image_part], dim=1) + + assert tensor.shape == (batch_size, text_len + image_len, hidden_dim) + + # Simulate partial sharding for world_size=4, rank=1 + world_size = 4 + rank = 1 + dim = 1 + + # Extract parts + text_kept = tensor.narrow(dim, 0, text_len) + image_full = tensor.narrow(dim, text_len, image_len) + + # Shard only image part + image_chunks = image_full.chunk(world_size, dim=dim) + image_sharded = image_chunks[rank] + + # Concatenate back + result = torch.cat([text_kept, image_sharded], dim=dim) + + expected_len = text_len + image_len // world_size + assert result.shape == (batch_size, expected_len, hidden_dim) + assert result.shape == (2, 8 + 4, 64) # text_len + image_len/4 + + def test_gather_tensor_simulation(self): + """Simulate tensor gathering without distributed backend.""" + # Create sharded tensors (as if from different ranks) + batch_size, shard_seq_len, hidden_dim = 2, 4, 64 + world_size = 4 + + shards = [torch.randn(batch_size, shard_seq_len, hidden_dim) for _ in range(world_size)] + + # Simulate gathering (concatenate along dim 1) + gathered = torch.cat(shards, dim=1) + + assert gathered.shape == (batch_size, shard_seq_len * world_size, 
hidden_dim) + assert gathered.shape == (2, 16, 64) + + def test_padding_simulation(self): + """Simulate padding for non-divisible sequence lengths.""" + # Create tensor with non-divisible sequence length + batch_size, seq_len, hidden_dim = 2, 17, 64 # 17 not divisible by 4 + tensor = torch.randn(batch_size, seq_len, hidden_dim) + + world_size = 4 + dim = 1 + + # Calculate padding needed + remainder = seq_len % world_size + if remainder != 0: + pad_size = world_size - remainder + else: + pad_size = 0 + + assert pad_size == 3 # 17 + 3 = 20, divisible by 4 + + # Pad tensor + if pad_size > 0: + pad_shape = list(tensor.shape) + pad_shape[dim] = pad_size + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + padded = torch.cat([tensor, padding], dim=dim) + else: + padded = tensor + + assert padded.shape == (batch_size, seq_len + pad_size, hidden_dim) + assert padded.shape == (2, 20, 64) + + # Now can shard evenly + chunks = padded.chunk(world_size, dim=dim) + assert all(c.shape == (2, 5, 64) for c in chunks) + + +# ============================================================================= +# Additional tests for sequence_parallel.py coverage +# ============================================================================= + + +@pytest.mark.cpu +class TestUnwrapModule: + """Test _unwrap_module function.""" + + def test_unwrap_simple_module(self): + """Test that a simple module returns itself.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module + + module = nn.Linear(10, 10) + result = _unwrap_module(module) + assert result is module + + def test_unwrap_sequential_single(self): + """Test unwrapping a Sequential with single child.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module + + inner = nn.Linear(10, 10) + wrapper = nn.Sequential(inner) + result = _unwrap_module(wrapper) + # Should unwrap to the inner module + assert result is inner + + def test_unwrap_nested_wrapper(self): + """Test unwrapping nested single-child wrappers.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module + + inner = nn.Linear(10, 10) + wrapper1 = nn.Sequential(inner) + wrapper2 = nn.Sequential(wrapper1) + result = _unwrap_module(wrapper2) + # Should fully unwrap to the innermost module + assert result is inner + + +@pytest.mark.cpu +class TestSequenceParallelSplitHookInit: + """Test SequenceParallelSplitHook initialization and setup.""" + + def test_hook_init(self): + """Test SequenceParallelSplitHook initialization.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + metadata = { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + } + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook(metadata, config) + assert hook.metadata == metadata + assert hook.config == config + assert hook.module_forward_metadata is None # Not initialized until initialize_hook + + def test_hook_initialize(self): + """Test SequenceParallelSplitHook.initialize_hook.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + class DummyModule(nn.Module): + def forward(self, hidden_states, encoder_hidden_states): + return hidden_states + + metadata = { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + } + config = 
SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook(metadata, config) + module = DummyModule() + + # Initialize hook + result = hook.initialize_hook(module) + assert result is module + assert hook.module_forward_metadata is not None + assert hook.module_forward_metadata._cls is DummyModule + + +@pytest.mark.cpu +class TestSequenceParallelGatherHookInit: + """Test SequenceParallelGatherHook initialization.""" + + def test_hook_init_single_output(self): + """Test SequenceParallelGatherHook with single output.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelGatherHook + + metadata = SequenceParallelOutput(gather_dim=1, expected_dims=3) + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelGatherHook(metadata, config) + # Single output should be wrapped in a list + assert isinstance(hook.metadata, list) + assert len(hook.metadata) == 1 + assert hook.metadata[0].gather_dim == 1 + + def test_hook_init_multiple_outputs(self): + """Test SequenceParallelGatherHook with multiple outputs.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelGatherHook + + metadata = [ + SequenceParallelOutput(gather_dim=1, expected_dims=3), + SequenceParallelOutput(gather_dim=2, expected_dims=4), + ] + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelGatherHook(metadata, config) + assert len(hook.metadata) == 2 + assert hook.metadata[0].gather_dim == 1 + assert hook.metadata[1].gather_dim == 2 + + +@pytest.mark.cpu +class TestResolveTextLen: + """Test _resolve_text_len in SequenceParallelSplitHook.""" + + def test_resolve_int_source(self): + """Test resolving text length from integer source.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + class DummyModule(nn.Module): + def forward(self, x, txt_ids): + return x + + partial_input = SequenceParallelPartialInput( + split_dim=1, + text_len_source=256, # Fixed integer + expected_dims=3, + ) + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook({"x": partial_input}, config) + hook.initialize_hook(DummyModule()) + + # Resolve with integer source + text_len = hook._resolve_text_len(partial_input, (), {}) + assert text_len == 256 + + def test_resolve_string_source_from_tensor(self): + """Test resolving text length from tensor parameter.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + class DummyModule(nn.Module): + def forward(self, x, txt_ids): + return x + + partial_input = SequenceParallelPartialInput( + split_dim=1, + text_len_source="txt_ids", # Get from parameter + expected_dims=3, + ) + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook({"x": partial_input}, config) + hook.initialize_hook(DummyModule()) + + # Provide txt_ids tensor + txt_ids = torch.randn(128, 64) # shape[0] = 128 + kwargs = {"txt_ids": txt_ids} + + text_len = hook._resolve_text_len(partial_input, (), kwargs) + assert text_len == 128 + + def test_resolve_text_len_caching(self): + """Test that text length is cached.""" + from 
vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + class DummyModule(nn.Module): + def forward(self, x, txt_ids): + return x + + partial_input = SequenceParallelPartialInput( + split_dim=1, + text_len_source="txt_ids", + expected_dims=3, + ) + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook({"x": partial_input}, config) + hook.initialize_hook(DummyModule()) + + txt_ids = torch.randn(64, 32) + kwargs = {"txt_ids": txt_ids} + + # First call - should populate cache + hook._resolve_text_len(partial_input, (), kwargs) + assert "txt_ids" in hook._text_len_cache + assert hook._text_len_cache["txt_ids"] == 64 + + # Second call - should use cache + text_len = hook._resolve_text_len(partial_input, (), kwargs) + assert text_len == 64 + + +@pytest.mark.cpu +class TestHookNameTemplates: + """Test hook name template generation.""" + + def test_input_hook_name(self): + """Test input hook name format.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _SP_INPUT_HOOK_TEMPLATE + + name = _SP_INPUT_HOOK_TEMPLATE.format("blocks.0") + assert name == "sp_input---blocks.0" + + def test_output_hook_name(self): + """Test output hook name format.""" + from vllm_omni.diffusion.hooks.sequence_parallel import _SP_OUTPUT_HOOK_TEMPLATE + + name = _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out") + assert name == "sp_output---proj_out" + + +@pytest.mark.cpu +class TestApplyRemoveSequenceParallel: + """Test apply_sequence_parallel and remove_sequence_parallel functions.""" + + def test_apply_sp_registers_hooks(self): + """Test that apply_sequence_parallel registers hooks on modules.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import ( + _SP_INPUT_HOOK_TEMPLATE, + _SP_OUTPUT_HOOK_TEMPLATE, + apply_sequence_parallel, + ) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.proj_in = nn.Linear(10, 10) + self.proj_out = nn.Linear(10, 10) + + def forward(self, hidden_states): + x = self.proj_in(hidden_states) + return self.proj_out(x) + + model = SimpleModel() + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + plan = { + "proj_in": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)}, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + # Apply SP + apply_sequence_parallel(model, config, plan) + + # Check hooks are registered + + assert hasattr(model.proj_in, "_hook_registry") + assert hasattr(model.proj_out, "_hook_registry") + + proj_in_registry = model.proj_in._hook_registry + proj_out_registry = model.proj_out._hook_registry + + assert _SP_INPUT_HOOK_TEMPLATE.format("proj_in") in proj_in_registry._hooks + assert _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out") in proj_out_registry._hooks + + def test_remove_sp_removes_hooks(self): + """Test that remove_sequence_parallel removes hooks from modules.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import ( + _SP_INPUT_HOOK_TEMPLATE, + _SP_OUTPUT_HOOK_TEMPLATE, + apply_sequence_parallel, + remove_sequence_parallel, + ) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.proj_in = nn.Linear(10, 10) + self.proj_out = nn.Linear(10, 10) + + def forward(self, hidden_states): + x = self.proj_in(hidden_states) + return self.proj_out(x) + + 
model = SimpleModel() + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + plan = { + "proj_in": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)}, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + # Apply then remove SP + apply_sequence_parallel(model, config, plan) + remove_sequence_parallel(model, plan) + + # Check hooks are removed + proj_in_registry = model.proj_in._hook_registry + proj_out_registry = model.proj_out._hook_registry + + assert _SP_INPUT_HOOK_TEMPLATE.format("proj_in") not in proj_in_registry._hooks + assert _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out") not in proj_out_registry._hooks + + def test_apply_sp_with_wildcard(self): + """Test apply_sequence_parallel with wildcard module names.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import ( + _SP_INPUT_HOOK_TEMPLATE, + apply_sequence_parallel, + ) + + class Block(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([Block() for _ in range(3)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + model = Model() + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + plan = { + "blocks.*": {"x": SequenceParallelInput(split_dim=1, expected_dims=3)}, + } + + # Apply SP + apply_sequence_parallel(model, config, plan) + + # Check all blocks have hooks registered + for i, block in enumerate(model.blocks): + assert hasattr(block, "_hook_registry") + registry = block._hook_registry + assert _SP_INPUT_HOOK_TEMPLATE.format("blocks.*") in registry._hooks + + +@pytest.mark.cpu +class TestDimensionValidation: + """Test expected_dims validation in hooks.""" + + def test_skip_shard_on_wrong_dims(self): + """Test that sharding is skipped when tensor dims don't match expected.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook + + class DummyModule(nn.Module): + def forward(self, x): + return x + + # Expect 3D tensor + metadata = { + "x": SequenceParallelInput(split_dim=1, expected_dims=3), + } + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1) + + hook = SequenceParallelSplitHook(metadata, config) + hook.initialize_hook(DummyModule()) + + # Provide 4D tensor (wrong dims) + tensor_4d = torch.randn(2, 4, 8, 16) + + # _prepare_sp_input should return tensor unchanged when dims don't match + result = hook._prepare_sp_input(tensor_4d, metadata["x"], (), {}) + # Since expected_dims=3 but tensor has 4 dims, should return original + assert result.shape == tensor_4d.shape + + +@pytest.mark.cpu +class TestSequenceParallelConfig: + """Test SequenceParallelConfig dataclass.""" + + def test_config_defaults_invalid(self): + """Test that SequenceParallelConfig with default values raises error. + + At least one of ulysses_degree or ring_degree must be > 1 to enable SP. 
+ """ + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + + with pytest.raises(ValueError, match="must be > 1"): + SequenceParallelConfig() # Both defaults are 1, which is invalid + + def test_config_ulysses_only(self): + """Test SequenceParallelConfig with Ulysses only.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + + config = SequenceParallelConfig(ulysses_degree=4, ring_degree=1) + assert config.sequence_parallel_size == 4 + + def test_config_ring_only(self): + """Test SequenceParallelConfig with Ring only.""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + + config = SequenceParallelConfig(ulysses_degree=1, ring_degree=4) + assert config.sequence_parallel_size == 4 + + def test_config_hybrid(self): + """Test SequenceParallelConfig with hybrid (Ulysses + Ring).""" + from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig + + config = SequenceParallelConfig(ulysses_degree=2, ring_degree=4) + assert config.sequence_parallel_size == 8 diff --git a/tests/diffusion/lora/test_base_linear.py b/tests/diffusion/lora/test_base_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..42bdf6526a5feef8263f1edfd5aff149e5196f43 --- /dev/null +++ b/tests/diffusion/lora/test_base_linear.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from vllm_omni.diffusion.lora.layers.base_linear import DiffusionBaseLinearLayerWithLoRA + + +@dataclass +class _DummyLoRAConfig: + fully_sharded_loras: bool = False + + +class _DummyQuantMethod: + def __init__(self, weight: torch.Tensor): + self._weight = weight + + def apply(self, _base_layer, x: torch.Tensor, bias: torch.Tensor | None): + y = x @ self._weight.t() + if bias is not None: + y = y + bias + return y + + +def test_diffusion_base_linear_apply_multi_slice(): + # Build a fake diffusion LoRA layer with 2 slices and rank=2. + layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 3 + out_slices = (2, 1) + rank = 2 + + # Base weight: identity-ish mapping to make base output easy to reason about. 
+ base_weight = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + # Allocate stacked weights: (max_loras=1, 1, rank, in_dim) and (1, 1, out_slice, rank) + a0 = torch.zeros((1, 1, rank, in_dim)) + b0 = torch.zeros((1, 1, out_slices[0], rank)) + a1 = torch.zeros((1, 1, rank, in_dim)) + b1 = torch.zeros((1, 1, out_slices[1], rank)) + + # Slice 0: delta0 = (x @ A0.T) @ B0.T + A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3) + B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2) + a0[0, 0, :, :] = A0 + b0[0, 0, :, :] = B0 + + # Slice 1: delta1 = (x @ A1.T) @ B1.T + A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3) + B1 = torch.tensor([[2.0, 0.0]]) # (1, 2) + a1[0, 0, :, :] = A1 + b1[0, 0, :, :] = B1 + + layer.lora_a_stacked = (a0, a1) + layer.lora_b_stacked = (b0, b1) + layer.output_slices = out_slices + + x = torch.tensor([[1.0, 2.0, 3.0]]) + out = layer.apply(x) + + # Base output is identity: [1,2,3] + base_out = x @ base_weight.t() + # delta0: + # (x @ A0.T) = [1,2] + # [1,2] @ B0.T = [1,2] + delta0 = torch.tensor([[1.0, 2.0]]) + # delta1: + # (x @ A1.T) = [3,1] + # [3,1] @ B1.T = [6] + delta1 = torch.tensor([[6.0]]) + expected = torch.cat([base_out[:, :2] + delta0, base_out[:, 2:3] + delta1], dim=-1) + assert torch.allclose(out, expected) + + +def test_diffusion_base_linear_reset_lora_disables_fast_path(monkeypatch): + # Verify that after reset_lora(), apply() skips LoRA matmuls even if the + # LoRA tensors are still allocated and non-empty. + from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA + + layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 2 + out_dim = 2 + rank = 1 + + base_weight = torch.eye(in_dim) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + a = torch.ones((1, 1, rank, in_dim)) + b = torch.tensor([[[[1.0], [2.0]]]]) # (1,1,out_dim,rank) + + layer.lora_a_stacked = (a,) + layer.lora_b_stacked = (b,) + layer.output_slices = (out_dim,) + layer._diffusion_lora_active_slices = (True,) + + x = torch.tensor([[1.0, 2.0]]) + out_active = layer.apply(x) + assert torch.allclose(out_active, torch.tensor([[4.0, 8.0]])) + + monkeypatch.setattr(BaseLinearLayerWithLoRA, "reset_lora", lambda self, index: None) + layer.reset_lora(0) + + assert layer._diffusion_lora_active_slices == (False,) + out_inactive = layer.apply(x) + assert torch.allclose(out_inactive, x) + + +def test_diffusion_base_linear_apply_respects_inactive_slices(): + # Build a fake diffusion LoRA layer with 2 slices and rank=2. 
+ layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 3 + out_slices = (2, 1) + rank = 2 + + base_weight = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + a0 = torch.zeros((1, 1, rank, in_dim)) + b0 = torch.zeros((1, 1, out_slices[0], rank)) + a1 = torch.zeros((1, 1, rank, in_dim)) + b1 = torch.zeros((1, 1, out_slices[1], rank)) + + A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3) + B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2) + a0[0, 0, :, :] = A0 + b0[0, 0, :, :] = B0 + + A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3) + B1 = torch.tensor([[2.0, 0.0]]) # (1, 2) + a1[0, 0, :, :] = A1 + b1[0, 0, :, :] = B1 + + layer.lora_a_stacked = (a0, a1) + layer.lora_b_stacked = (b0, b1) + layer.output_slices = out_slices + layer._diffusion_lora_active_slices = (True, False) + + x = torch.tensor([[1.0, 2.0, 3.0]]) + out = layer.apply(x) + + # Only the first slice should be adapted. + expected = torch.tensor([[2.0, 4.0, 3.0]]) + assert torch.allclose(out, expected) diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..84fafe3bc9e7e6884a10aebc24a99e99f8afa653 --- /dev/null +++ b/tests/diffusion/lora/test_lora_manager.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from vllm.lora.lora_weights import LoRALayerWeights +from vllm.lora.utils import get_supported_lora_modules +from vllm.model_executor.layers.linear import LinearBase + +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager +from vllm_omni.lora.request import LoRARequest + + +class _DummyLoRALayer: + def __init__(self, n_slices: int, output_slices: tuple[int, ...]): + self.n_slices = n_slices + self.output_slices = output_slices + self.set_calls: list[ + tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] + ] = [] + self.reset_calls: int = 0 + + def set_lora(self, index: int, lora_a, lora_b): + assert index == 0 + self.set_calls.append((lora_a, lora_b)) + + def reset_lora(self, index: int): + assert index == 0 + self.reset_calls += 1 + + +class _FakeLinearBase(LinearBase): + def __init__(self): + torch.nn.Module.__init__(self) + + +def test_lora_manager_supported_modules_are_stable_with_wrapped_layers(monkeypatch): + # Simulate a pipeline that already contains LoRA wrappers where the original + # LinearBase is nested under ".base_layer". + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + pipeline = torch.nn.Module() + pipeline.transformer = torch.nn.Module() + pipeline.transformer.foo = _DummyBaseLayerWithLoRA(_FakeLinearBase()) + + # vLLM helper would see only the nested LinearBase and yield "base_layer". 
+ assert get_supported_lora_modules(pipeline) == ["base_layer"] + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + assert "foo" in manager._supported_lora_modules + assert "base_layer" not in manager._supported_lora_modules + + +def test_lora_manager_replace_layers_does_not_rewrap_base_layer(monkeypatch): + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs): + if isinstance(layer, _FakeLinearBase): + return _DummyBaseLayerWithLoRA(layer) + return layer + + replace_calls: list[str] = [] + + def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module): + replace_calls.append(module_name) + setattr(root, module_name, submodule) + + monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion) + monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule) + + pipeline = torch.nn.Module() + pipeline.transformer = torch.nn.Module() + pipeline.transformer.foo = _FakeLinearBase() + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + peft_helper = type("_PH", (), {"r": 1})() + + manager._replace_layers_with_lora(peft_helper) + manager._replace_layers_with_lora(peft_helper) + + # Only the top-level layer should have been replaced; nested ".base_layer" + # must be skipped to avoid nesting LoRA wrappers. + assert replace_calls == ["foo"] + + +def test_lora_manager_replaces_packed_layer_when_targeting_sublayers(monkeypatch): + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs): + return _DummyBaseLayerWithLoRA(layer) + + replace_calls: list[str] = [] + + def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module): + replace_calls.append(module_name) + setattr(root, module_name, submodule) + + monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion) + monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule) + + pipeline = torch.nn.Module() + pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]} + pipeline.transformer = torch.nn.Module() + pipeline.transformer.to_qkv = _FakeLinearBase() + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + # Treat the dummy layer as a packed 3-slice projection so the manager uses + # `packed_modules_mapping` to decide replacement based on target_modules. 
+ monkeypatch.setattr(manager, "_get_packed_modules_list", lambda _module: ["q", "k", "v"]) + + peft_helper = type("_PH", (), {"r": 1, "target_modules": ["to_q"]})() + manager._replace_layers_with_lora(peft_helper) + + assert replace_calls == ["to_qkv"] + + +def test_lora_manager_activates_fused_lora_on_packed_layer(): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1)) + manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer} + + rank = 2 + A = torch.ones((rank, 4)) + B = torch.arange(0, sum(packed_layer.output_slices) * rank, dtype=torch.bfloat16).view(-1, rank) + lora = LoRALayerWeights( + module_name="transformer.blocks.0.attn.to_qkv", + rank=rank, + lora_alpha=rank, + lora_a=A, + lora_b=B, + ) + manager._registered_adapters = { + 7: type( + "LM", + (), + { + "id": 7, + "loras": {"transformer.blocks.0.attn.to_qkv": lora}, + "get_lora": lambda self, k: self.loras.get(k), + }, + )() + } + manager._adapter_scales = {7: 0.5} + + manager._activate_adapter(7) + + assert packed_layer.reset_calls == 0 + assert len(packed_layer.set_calls) == 1 + lora_a_list, lora_b_list = packed_layer.set_calls[0] + assert isinstance(lora_a_list, list) + assert isinstance(lora_b_list, list) + assert len(lora_a_list) == 3 + assert len(lora_b_list) == 3 + assert all(torch.allclose(a, A) for a in lora_a_list) + # B should be split into 3 slices and scaled. + b0, b1, b2 = lora_b_list + assert b0.shape[0] == 2 and b1.shape[0] == 1 and b2.shape[0] == 1 + assert torch.allclose(torch.cat([b0, b1, b2], dim=0), B * 0.5) + + +def test_lora_manager_activates_packed_lora_from_sublayers(): + pipeline = torch.nn.Module() + pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]} + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1)) + manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer} + + rank = 2 + loras: dict[str, LoRALayerWeights] = {} + for name, out_dim in zip(["to_q", "to_k", "to_v"], [2, 1, 1]): + loras[f"transformer.blocks.0.attn.{name}"] = LoRALayerWeights( + module_name=f"transformer.blocks.0.attn.{name}", + rank=rank, + lora_alpha=rank, + lora_a=torch.ones((rank, 4)) * (1 if name == "to_q" else 2), + lora_b=torch.ones((out_dim, rank)) * (3 if name == "to_q" else 4), + ) + + manager._registered_adapters = { + 1: type("LM", (), {"id": 1, "loras": loras, "get_lora": lambda self, k: self.loras.get(k)})() + } + manager._adapter_scales = {1: 2.0} + + manager._activate_adapter(1) + + assert packed_layer.reset_calls == 0 + assert len(packed_layer.set_calls) == 1 + lora_a_list, lora_b_list = packed_layer.set_calls[0] + assert isinstance(lora_a_list, list) + assert isinstance(lora_b_list, list) + assert len(lora_a_list) == 3 + assert len(lora_b_list) == 3 + # Scale should apply to B only. 
+ assert torch.allclose(lora_b_list[0], torch.ones((2, rank)) * 3 * 2.0) + assert torch.allclose(lora_b_list[1], torch.ones((1, rank)) * 4 * 2.0) + assert torch.allclose(lora_b_list[2], torch.ones((1, rank)) * 4 * 2.0) + + +def _dummy_lora_request(adapter_id: int) -> LoRARequest: + return LoRARequest( + lora_name=f"adapter_{adapter_id}", + lora_int_id=adapter_id, + lora_path=f"/tmp/adapter_{adapter_id}", + ) + + +def test_lora_manager_evicts_lru_adapter_when_cache_full(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + req1 = _dummy_lora_request(1) + req2 = _dummy_lora_request(2) + req3 = _dummy_lora_request(3) + + manager.set_active_adapter(req1, lora_scale=1.0) + manager.set_active_adapter(req2, lora_scale=1.0) + + # Touch adapter 1 so adapter 2 becomes LRU. + manager.set_active_adapter(req1, lora_scale=1.0) + + manager.set_active_adapter(req3, lora_scale=1.0) + + assert set(manager.list_adapters()) == {1, 3} + + +def test_lora_manager_does_not_evict_pinned_adapter(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) + assert manager.pin_adapter(1) + + manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) + manager.set_active_adapter(_dummy_lora_request(3), lora_scale=1.0) + + assert set(manager.list_adapters()) == {1, 3} + + +def test_lora_manager_warns_when_all_adapters_pinned(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) + manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) + + assert manager.pin_adapter(1) + assert manager.pin_adapter(2) + + manager.max_cached_adapters = 1 + manager._evict_if_needed() + + assert set(manager.list_adapters()) == {1, 2} diff --git a/tests/diffusion/models/z_image/test_zimage_tp_constraints.py b/tests/diffusion/models/z_image/test_zimage_tp_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..a276274a0dd6f69ecb9be70f9bfb812ea1732287 --- /dev/null +++ 
b/tests/diffusion/models/z_image/test_zimage_tp_constraints.py @@ -0,0 +1,44 @@ +import pytest + +from vllm_omni.diffusion.models.z_image.z_image_transformer import validate_zimage_tp_constraints + + +def test_validate_zimage_tp_constraints_tp2_ok(): + ffn_hidden_dim, final_out_dims, supported_tp = validate_zimage_tp_constraints( + dim=3840, + n_heads=30, + n_kv_heads=30, + in_channels=16, + all_patch_size=(2,), + all_f_patch_size=(1,), + tensor_parallel_size=2, + ) + assert ffn_hidden_dim == 10240 + assert final_out_dims == [64] + assert supported_tp == [1, 2] + + +def test_validate_zimage_tp_constraints_tp4_fails_on_heads(): + with pytest.raises(ValueError, match=r"n_heads % tensor_parallel_size"): + validate_zimage_tp_constraints( + dim=3840, + n_heads=30, + n_kv_heads=30, + in_channels=16, + all_patch_size=(2,), + all_f_patch_size=(1,), + tensor_parallel_size=4, + ) + + +def test_validate_zimage_tp_constraints_tp3_fails_on_ffn_hidden_dim(): + with pytest.raises(ValueError, match=r"ffn_hidden_dim % tensor_parallel_size"): + validate_zimage_tp_constraints( + dim=3840, + n_heads=30, + n_kv_heads=30, + in_channels=16, + all_patch_size=(2,), + all_f_patch_size=(1,), + tensor_parallel_size=3, + ) diff --git a/tests/diffusion/test_diffusion_worker.py b/tests/diffusion/test_diffusion_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..220f210a3d5ac6b7050a23b8874739b9241ee09e --- /dev/null +++ b/tests/diffusion/test_diffusion_worker.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Unit tests for DiffusionWorker class. + +This module tests the DiffusionWorker implementation: +- load_weights: Loading model weights +- sleep: Putting worker into sleep mode (levels 1 and 2) +- wake_up: Waking worker from sleep mode +""" + +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker + + +@pytest.fixture +def mock_od_config(): + """Create a mock OmniDiffusionConfig.""" + config = Mock() + config.num_gpus = 1 + config.master_port = 12345 + config.enable_sleep_mode = False + config.cache_backend = None + config.cache_config = None + config.model = "test-model" + return config + + +@pytest.fixture +def mock_gpu_worker(mock_od_config): + """Create a DiffusionWorker with mocked initialization.""" + with patch.object(DiffusionWorker, "init_device"): + worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config) + # Mock the model_runner with pipeline + worker.model_runner = Mock() + worker.model_runner.pipeline = Mock() + worker.device = torch.device("cuda", 0) + worker._sleep_saved_buffers = {} + return worker + + +class TestDiffusionWorkerLoadWeights: + """Test DiffusionWorker.load_weights method.""" + + def test_load_weights_calls_pipeline(self, mock_gpu_worker): + """Test that load_weights delegates to model_runner.load_weights.""" + # Setup mock weights + mock_weights = [ + ("layer1.weight", torch.randn(10, 10)), + ("layer2.weight", torch.randn(20, 20)), + ] + expected_loaded = {"layer1.weight", "layer2.weight"} + + # Configure model_runner mock + mock_gpu_worker.model_runner.load_weights = Mock(return_value=expected_loaded) + + # Call load_weights + result = mock_gpu_worker.load_weights(mock_weights) + + # Verify model_runner.load_weights was called with the weights + mock_gpu_worker.model_runner.load_weights.assert_called_once_with(mock_weights) + assert result == expected_loaded + + def 
test_load_weights_empty_iterable(self, mock_gpu_worker): + """Test load_weights with empty weights iterable.""" + mock_gpu_worker.model_runner.load_weights = Mock(return_value=set()) + + result = mock_gpu_worker.load_weights([]) + + mock_gpu_worker.model_runner.load_weights.assert_called_once_with([]) + assert result == set() + + +class TestDiffusionWorkerSleep: + """Test DiffusionWorker.sleep method.""" + + @patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_sleep_level_1(self, mock_allocator_class, mock_platform, mock_gpu_worker): + """Test sleep mode level 1 (offload weights only).""" + # Setup memory info mocks + # Before sleep: 1GB free + # After sleep: 3GB free (freed 2GB) + mock_platform.get_free_memory.side_effect = [ + 1 * 1024**3, # Before sleep + 3 * 1024**3, # After sleep + ] + mock_platform.get_device_total_memory.return_value = 8 * 1024**3 + + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.sleep = Mock() + + # Call sleep with level 1 + result = mock_gpu_worker.sleep(level=1) + + # Verify sleep was called with correct tags + mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",)) + assert result is True + # Verify buffers were NOT saved (level 1 doesn't save buffers) + assert len(mock_gpu_worker._sleep_saved_buffers) == 0 + + @patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_sleep_level_2(self, mock_allocator_class, mock_platform, mock_gpu_worker): + """Test sleep mode level 2 (offload all, save buffers).""" + # Setup memory info mocks + mock_platform.get_free_memory.side_effect = [ + 1 * 1024**3, # Before sleep + 5 * 1024**3, # After sleep (freed 4GB) + ] + mock_platform.get_device_total_memory.return_value = 8 * 1024**3 + + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.sleep = Mock() + + # Mock pipeline buffers + mock_buffer1 = torch.randn(10, 10) + mock_buffer2 = torch.randn(20, 20) + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( + return_value=[ + ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call sleep with level 2 + result = mock_gpu_worker.sleep(level=2) + + # Verify sleep was called with empty tags (offload all) + mock_allocator.sleep.assert_called_once_with(offload_tags=tuple()) + assert result is True + + # Verify buffers were saved + assert len(mock_gpu_worker._sleep_saved_buffers) == 2 + assert "buffer1" in mock_gpu_worker._sleep_saved_buffers + assert "buffer2" in mock_gpu_worker._sleep_saved_buffers + + @patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_platform, mock_gpu_worker): + """Test that sleep validates memory was actually freed.""" + # Simulate memory increase (should trigger assertion error) + mock_platform.get_free_memory.side_effect = [ + 3 * 1024**3, # Before sleep: 3GB free + 1 * 1024**3, # After sleep: 1GB free (negative freed!) 
+ ] + mock_platform.get_device_total_memory.return_value = 8 * 1024**3 + + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.sleep = Mock() + + # This should raise an assertion error + with pytest.raises(AssertionError, match="Memory usage increased after sleeping"): + mock_gpu_worker.sleep(level=1) + + +class TestDiffusionWorkerWakeUp: + """Test DiffusionWorker.wake_up method.""" + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up without saved buffers (level 1 sleep).""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Ensure no saved buffers + mock_gpu_worker._sleep_saved_buffers = {} + + # Call wake_up + result = mock_gpu_worker.wake_up(tags=["weights"]) + + # Verify allocator.wake_up was called + mock_allocator.wake_up.assert_called_once_with(["weights"]) + assert result is True + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up with saved buffers (level 2 sleep).""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Create saved buffers + saved_buffer1 = torch.randn(10, 10) + saved_buffer2 = torch.randn(20, 20) + mock_gpu_worker._sleep_saved_buffers = { + "buffer1": saved_buffer1, + "buffer2": saved_buffer2, + } + + # Mock pipeline buffers (these will be restored) + mock_buffer1 = Mock() + mock_buffer1.data = Mock() + mock_buffer2 = Mock() + mock_buffer2.data = Mock() + + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( + return_value=[ + ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call wake_up + result = mock_gpu_worker.wake_up(tags=None) + + # Verify allocator.wake_up was called + mock_allocator.wake_up.assert_called_once_with(None) + + # Verify buffers were restored + mock_buffer1.data.copy_.assert_called_once() + mock_buffer2.data.copy_.assert_called_once() + + # Verify saved buffers were cleared + assert len(mock_gpu_worker._sleep_saved_buffers) == 0 + assert result is True + + @patch("vllm.device_allocator.cumem.CuMemAllocator") + def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_worker): + """Test wake_up only restores buffers that were saved.""" + # Setup allocator mock + mock_allocator = Mock() + mock_allocator_class.get_instance = Mock(return_value=mock_allocator) + mock_allocator.wake_up = Mock() + + # Only save buffer1, not buffer2 + saved_buffer1 = torch.randn(10, 10) + mock_gpu_worker._sleep_saved_buffers = { + "buffer1": saved_buffer1, + } + + # Mock pipeline has both buffers + mock_buffer1 = Mock() + mock_buffer1.data = Mock() + mock_buffer2 = Mock() + mock_buffer2.data = Mock() + + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( + return_value=[ + ("buffer1", mock_buffer1), + ("buffer2", mock_buffer2), + ] + ) + + # Call wake_up + result = mock_gpu_worker.wake_up() + + # Verify only buffer1 was restored + mock_buffer1.data.copy_.assert_called_once() + # buffer2 should NOT be restored since it wasn't saved + mock_buffer2.data.copy_.assert_not_called() + + assert result is True diff --git a/tests/distributed/omni_connectors/test_adapter_and_flow.py b/tests/distributed/omni_connectors/test_adapter_and_flow.py 
new file mode 100644 index 0000000000000000000000000000000000000000..2dc472393da1aa5252d6b2768ed1198737523bf6 --- /dev/null +++ b/tests/distributed/omni_connectors/test_adapter_and_flow.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector, try_send_via_connector +from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector +from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec, OmniTransferConfig +from vllm_omni.distributed.omni_connectors.utils.initialization import get_connectors_config_for_stage + + +@pytest.fixture +def mock_objects(): + return {"connector": MagicMock(), "metrics": MagicMock(), "queue_fn": MagicMock()} + + +def test_send_success(mock_objects): + """Test try_send_via_connector success path.""" + # Setup + mock_connector = mock_objects["connector"] + mock_metrics = mock_objects["metrics"] + mock_queue_fn = mock_objects["queue_fn"] + + stage_id = 0 + next_stage_id = 1 + req_id = "req_123" + inputs = {"input_ids": [1, 2, 3]} + sampling_params = {"temperature": 0.7} + prompt = "test prompt" + + # Mock connector.put return + # Returns: (success, size, metadata) + mock_metadata = {"handle": "xyz"} + mock_connector.put.return_value = (True, 100, mock_metadata) + + # Execute + result = try_send_via_connector( + connector=mock_connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=inputs, + sampling_params=sampling_params, + original_prompt=prompt, + next_stage_queue_submit_fn=mock_queue_fn, + metrics=mock_metrics, + ) + + # Verify + assert result is True + + # 1. Verify connector.put called correctly + mock_connector.put.assert_called_once() + args, _ = mock_connector.put.call_args + assert args[0] == "0" # from_stage + assert args[1] == "1" # to_stage + assert args[2] == req_id + # Verify payload structure in put + payload = args[3] + assert payload["engine_inputs"] == inputs + assert payload["sampling_params"] == sampling_params + + # 2. Verify queue notification submitted + mock_queue_fn.assert_called_once() + notify_payload = mock_queue_fn.call_args[0][0] + assert notify_payload["request_id"] == req_id + assert notify_payload["from_connector"] is True + assert notify_payload["connector_metadata"] == mock_metadata + + # 3. 
Verify metrics recorded + mock_metrics.on_forward.assert_called_once() + + +def test_send_fail(mock_objects): + """Test try_send_via_connector when connector fails.""" + mock_connector = mock_objects["connector"] + mock_metrics = mock_objects["metrics"] + mock_queue_fn = mock_objects["queue_fn"] + + mock_connector.put.return_value = (False, 0, None) + + result = try_send_via_connector( + connector=mock_connector, + stage_id=0, + next_stage_id=1, + req_id="req_fail", + next_inputs={}, + sampling_params={}, + original_prompt="", + next_stage_queue_submit_fn=mock_queue_fn, + metrics=mock_metrics, + ) + + assert result is False + mock_queue_fn.assert_not_called() + + +def test_recv_success(mock_objects): + """Test try_recv_via_connector success path.""" + mock_connector = mock_objects["connector"] + + # Setup task received from queue + task = { + "request_id": "req_recv", + "from_connector": True, + "from_stage": "0", + "connector_metadata": {"handle": "xyz"}, + } + + # Setup connectors dict + connectors = {("0", "1"): mock_connector} + + # Mock connector.get return + expected_data = {"engine_inputs": {"ids": [1]}} + # get returns: (data_obj, size) + mock_connector.get.return_value = (expected_data, 50) + # serialize_obj needed for metrics calculation if size not returned directly + mock_connector.serialize_obj.return_value = b"bytes" + + # Execute + # We are stage 1 receiving from stage 0 + inputs, rx_metrics = try_recv_via_connector(task, connectors, stage_id=1) + + # Verify + assert inputs == expected_data["engine_inputs"] + assert rx_metrics is not None + mock_connector.get.assert_called_once_with("0", "1", "req_recv", metadata={"handle": "xyz"}) + + +def test_recv_no_connector(): + """Test recv fails when no connector exists for edge.""" + task = {"request_id": "req_missing", "from_connector": True, "from_stage": "0"} + connectors = {} # Empty connectors + + inputs, _ = try_recv_via_connector(task, connectors, stage_id=1) + assert inputs is None + + +def test_shm_connector_flow(): + """ + Verify the full flow: Send -> Adapter -> Connector -> Adapter -> Recv. + Using real SharedMemoryConnector (inline mode for simplicity). + """ + # 1. Setup Connector + config = {"shm_threshold_bytes": 1024} # Large threshold to use inline + connector = SharedMemoryConnector(config) + connectors_map = {("0", "1"): connector} + + # 2. Setup Data + stage_id = 0 + next_stage_id = 1 + req_id = "flow_req" + inputs = {"tokens": [10, 20, 30]} + sampling_params = {"n": 1} + + # Queue capture mechanism + queue_capture = [] + + def mock_submit(payload): + queue_capture.append(payload) + + mock_metrics = MagicMock() + + # 3. Send + success = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=inputs, + sampling_params=sampling_params, + original_prompt="prompt", + next_stage_queue_submit_fn=mock_submit, + metrics=mock_metrics, + ) + assert success is True + assert len(queue_capture) == 1 + + # 4. Recv + # The 'task' is what would be popped from the queue + received_task = queue_capture[0] + + # Verify queue payload contains what we expect + assert received_task["from_connector"] is True + assert received_task["from_stage"] == "0" + + # Decode + decoded_inputs, _ = try_recv_via_connector(received_task, connectors_map, stage_id=1) + + # 5. 
Verify Data Integrity + assert decoded_inputs == inputs + + +def test_get_connectors_for_stage(): + """Test filtering logic for stage config.""" + # Config has edges: 0->1, 1->2 + config = OmniTransferConfig(connectors={("0", "1"): ConnectorSpec(name="C1"), ("1", "2"): ConnectorSpec(name="C2")}) + + # Get config for Stage 1 + # Stage 1 receives from 0 (input) and sends to 2 (output) + # get_connectors_config_for_stage ONLY returns INPUT connectors for the worker to initialize + + stage_config = get_connectors_config_for_stage(config, stage_id=1) + + # Should contain "from_stage_0" + assert "from_stage_0" in stage_config + assert stage_config["from_stage_0"]["spec"]["name"] == "C1" + + # Should NOT contain "from_stage_1" or related to output + assert "from_stage_1" not in stage_config + + # Verify Stage 2 + stage_2_config = get_connectors_config_for_stage(config, stage_id=2) + assert "from_stage_1" in stage_2_config + assert stage_2_config["from_stage_1"]["spec"]["name"] == "C2" + + +def test_recv_with_missing_metadata(): + """Test recv when queue payload is malformed (missing metadata).""" + # Connector expects metadata but task doesn't have it + task = { + "request_id": "req_bad", + "from_connector": True, + "from_stage": "0", + # Missing "connector_metadata" + } + mock_conn = MagicMock() + # If get is called with None metadata, connector usually handles it or adapter handles exception + mock_conn.get.side_effect = Exception("Get failed") + + connectors = {("0", "1"): mock_conn} + + inputs, _ = try_recv_via_connector(task, connectors, stage_id=1) + assert inputs is None diff --git a/tests/distributed/omni_connectors/test_basic_connectors.py b/tests/distributed/omni_connectors/test_basic_connectors.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b4b2d013300d7e3783763d7d8cbb99a493a110 --- /dev/null +++ b/tests/distributed/omni_connectors/test_basic_connectors.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector +from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory +from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec +from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer + + +def test_basic_serialization(): + """Test basic msgpack serialization.""" + data = {"key": "value", "list": [1, 2, 3]} + serialized = OmniSerializer.serialize(data) + assert isinstance(serialized, bytes) + + deserialized = OmniSerializer.deserialize(serialized) + assert data == deserialized + + +def test_tensor_serialization(): + """Test torch.Tensor serialization.""" + import torch + + tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + serialized = OmniSerializer.serialize(tensor) + deserialized = OmniSerializer.deserialize(serialized) + + assert torch.equal(tensor, deserialized) + + +def test_ndarray_serialization(): + """Test numpy.ndarray serialization.""" + import numpy as np + + arr = np.array([[1, 2, 3], [4, 5, 6]]) + serialized = OmniSerializer.serialize(arr) + deserialized = OmniSerializer.deserialize(serialized) + + assert np.array_equal(arr, deserialized) + + +def test_create_shm_connector(): + """Test creating SharedMemoryConnector via Factory.""" + spec = ConnectorSpec(name="SharedMemoryConnector", extra={"shm_threshold_bytes": 1024}) + connector = 
OmniConnectorFactory.create_connector(spec) + assert isinstance(connector, SharedMemoryConnector) + assert connector.threshold == 1024 + + +def test_create_unknown_connector(): + """Test error when creating unknown connector.""" + spec = ConnectorSpec(name="UnknownConnector") + with pytest.raises(ValueError): + OmniConnectorFactory.create_connector(spec) + + +@pytest.fixture +def shm_connector(): + config = {"shm_threshold_bytes": 100} # Small threshold for testing + return SharedMemoryConnector(config) + + +def test_put_get_inline(shm_connector): + """Test inline transfer for small data.""" + data = {"small": "data"} + # Ensure data is smaller than threshold (100 bytes) + + success, size, metadata = shm_connector.put("stage_0", "stage_1", "req_1", data) + assert success is True + assert "inline_bytes" in metadata + assert "shm" not in metadata + + # Retrieve + retrieved_data, ret_size = shm_connector.get("stage_0", "stage_1", "req_1", metadata) + assert data == retrieved_data + assert size == ret_size + + +def test_put_get_shm(shm_connector, monkeypatch): + """Test SHM transfer logic for large data (Mocked).""" + # Create data larger than 100 bytes + data = {"large": "x" * 200} + + # Mock SHM return values + mock_handle = {"name": "test_shm", "size": 200} + mock_write = MagicMock(return_value=mock_handle) + monkeypatch.setattr("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_write_bytes", mock_write) + + # When reading, return the serialized bytes of the data + serialized_data = shm_connector.serialize_obj(data) + mock_read = MagicMock(return_value=serialized_data) + monkeypatch.setattr("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_read_bytes", mock_read) + + # Put + success, size, metadata = shm_connector.put("stage_0", "stage_1", "req_2", data) + + assert success is True + # Should use SHM because data > threshold + assert "shm" in metadata + assert metadata["shm"] == mock_handle + assert "inline_bytes" not in metadata + + mock_write.assert_called_once() + + # Get + retrieved_data, ret_size = shm_connector.get("stage_0", "stage_1", "req_2", metadata) + + assert data == retrieved_data + mock_read.assert_called_once_with(mock_handle) + + +def test_get_invalid_metadata(shm_connector): + """Test get with invalid metadata.""" + result = shm_connector.get("stage_0", "stage_1", "req_3", {}) + assert result is None + + result = shm_connector.get("stage_0", "stage_1", "req_3", {"unknown": "format"}) + assert result is None diff --git a/tests/distributed/omni_connectors/test_kv_flow.py b/tests/distributed/omni_connectors/test_kv_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7ff79ca54db38c51ef1e7fe4cb51a240486098 --- /dev/null +++ b/tests/distributed/omni_connectors/test_kv_flow.py @@ -0,0 +1,251 @@ +import pytest +import torch + +from tests.utils import hardware_test +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import ( + OmniKVCacheConfig, + OmniKVTransferManager, +) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + +class MockConnector: + def __init__(self): + self.store = {} + + def put(self, from_stage, to_stage, put_key, data): + # The manager now passes full key as put_key + key = f"{from_stage}->{to_stage}:{put_key}" + self.store[key] = data + return True, len(str(data)), None # (success, size, metadata) + + def get(self, from_stage, to_stage, get_key, metadata=None): + # The manager now passes full key as get_key + key 
= f"{from_stage}->{to_stage}:{get_key}"
+        if key in self.store:
+            return self.store[key], len(str(self.store[key]))
+        return None
+
+
+@pytest.fixture
+def mock_connector():
+    return MockConnector()
+
+
+@pytest.fixture
+def kv_config():
+    return OmniKVCacheConfig(
+        connector_config={"type": "mock"},
+        from_stage="stage1",
+        to_stage="stage2",
+        stage_id="stage2",  # Acting as receiver for some tests
+        need_recv_cache=True,
+        need_send_cache=True,
+        recv_timeout=1.0,  # Short timeout for tests
+    )
+
+
+@pytest.fixture
+def common_constants():
+    return {
+        "num_layers": 2,
+        "num_heads": 4,
+        "head_dim": 16,
+        "block_size": 8,
+        "seq_len": 20,
+        "req_id": "req_test_1",
+    }
+
+
+@pytest.mark.cache
+@hardware_test(
+    res={"cuda": "L4"},
+    num_cards=2,
+)
+def test_manager_extraction(kv_config, mock_connector, common_constants):
+    """Test extraction and sending logic in OmniKVTransferManager."""
+    num_layers = common_constants["num_layers"]
+    block_size = common_constants["block_size"]
+    num_heads = common_constants["num_heads"]
+    head_dim = common_constants["head_dim"]
+    seq_len = common_constants["seq_len"]
+    req_id = common_constants["req_id"]
+
+    num_blocks = 10
+    kv_caches = []
+    for _ in range(num_layers):
+        k_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
+        v_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
+        # Stack K and V to create [2, num_blocks, block_size, n_heads, head_dim]
+        layer_cache = torch.stack([k_cache, v_cache], dim=0)
+        kv_caches.append(layer_cache)
+
+    block_ids = [1, 3, 5]
+    finished_reqs = {req_id: {"block_ids": block_ids, "seq_len": seq_len}}
+
+    manager = OmniKVTransferManager(kv_config)
+    # Inject the mock connector directly instead of going through the connector factory.
+    manager._connector = mock_connector
+
+    processed = manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32")
+
+    assert req_id in processed
+
+    # Check if data was put into connector
+    # Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id}
+    full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}"
+    expected_key = f"stage1->stage2:{full_request_id}"
+    assert expected_key in mock_connector.store
+
+    data = mock_connector.store[expected_key]
+    assert data["request_id"] == req_id
+    assert "layer_blocks" in data
+    assert len(data["layer_blocks"]["key_cache"]) == num_layers
+
+    # Verify shape of extracted tensor: [seq_len, heads, dim]
+    # Note: Manager detaches and moves to CPU
+    expected_shape = (seq_len, num_heads, head_dim)
+    assert data["layer_blocks"]["key_cache"][0].shape == expected_shape
+
+
+@pytest.mark.cache
+@hardware_test(
+    res={"cuda": "L4"},
+    num_cards=2,
+)
+def test_manager_reception(kv_config, mock_connector, common_constants):
+    """Test reception and injection logic in OmniKVTransferManager."""
+    num_layers = common_constants["num_layers"]
+    block_size = common_constants["block_size"]
+    num_heads = common_constants["num_heads"]
+    head_dim = common_constants["head_dim"]
+    seq_len = common_constants["seq_len"]
+    req_id = common_constants["req_id"]
+
+    expected_shape = (seq_len, num_heads, head_dim)
+    key_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
+    value_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
+
+    layer_blocks = {"key_cache": key_cache, "value_cache": value_cache}
+    metadata = {
+        "block_size": block_size,
+        "num_layers": num_layers,
+        "dtype": "float32",
+        "seq_len": seq_len,
+    }
+
+    data_to_receive = {
+        "request_id": req_id,
+        "layer_blocks": layer_blocks,
+        "metadata": metadata,
+        "block_ids": [],
+    }
+
+    # Per the kv_config fixture: from_stage="stage1", stage_id="stage2", so the
+    # receive edge is ("stage1", "stage2").
+    manager = OmniKVTransferManager(kv_config)
+    manager._connector = mock_connector
+
+    # Pre-populate connector with data
+    # Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id}
+    full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}"
+    store_key = f"stage1->stage2:{full_request_id}"
+    mock_connector.store[store_key] = data_to_receive
+
+    req = OmniDiffusionRequest(
+        prompts=["test_recv"],
+        sampling_params=OmniDiffusionSamplingParams(),
+        request_ids=[req_id],
+    )
+    # Note: receive_kv_cache() keys the lookup on the request_id; the
+    # req.need_kv_receive flag is checked by the runner, not by the manager.
+    success = manager.receive_kv_cache(req, target_device=torch.device("cpu"))
+
+    assert success
+    assert hasattr(req, "past_key_values")
+    assert hasattr(req, "kv_metadata")
+
+    assert len(req.past_key_values.key_cache) == num_layers
+    assert torch.allclose(req.past_key_values.key_cache[0], key_cache[0])
+    assert req.kv_metadata["seq_len"] == seq_len
+
+
+@pytest.mark.cache
+@hardware_test(
+    res={"cuda": "L4"},
+    num_cards=2,
+)
+def test_integration_flow(common_constants):
+    """Simulate extraction -> connector -> reception."""
+    num_layers = common_constants["num_layers"]
+    block_size = common_constants["block_size"]
+    num_heads = common_constants["num_heads"]
+    head_dim = common_constants["head_dim"]
+    req_id = common_constants["req_id"]
+
+    sender_config = OmniKVCacheConfig(
+        connector_config={"type": "mock"}, from_stage="sender", to_stage="receiver", need_send_cache=True
+    )
+    sender_manager = OmniKVTransferManager(sender_config)
+    connector = MockConnector()
+    sender_manager._connector = connector  # Shared connector
+
+    # Create Data
+    num_blocks = 5
+    kv_caches = []
+    for _ in range(num_layers):
+        layer = torch.randn(2, num_blocks, block_size, num_heads, head_dim)
+        kv_caches.append(layer)
+
+    finished_reqs = {req_id: {"block_ids": [0, 1], "seq_len": 10}}
+
+    # Send
+    sender_manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32")
+
+    receiver_config = OmniKVCacheConfig(
+        connector_config={"type": "mock"},
+        from_stage="sender",
+        stage_id="receiver",
+        need_recv_cache=True,
+        recv_timeout=1.0,
+    )
+    receiver_manager = OmniKVTransferManager(receiver_config)
+    receiver_manager._connector = connector  # Share the same mock connector instance
+
+    req = OmniDiffusionRequest(
+        prompts=["test_integ"],
+        sampling_params=OmniDiffusionSamplingParams(),
+        request_ids=[req_id],
+    )
+
+    # Receive
+    success = receiver_manager.receive_kv_cache(req)
+
+    # Verify
+    assert success
+    assert req.past_key_values is not None
+    assert req.kv_metadata["seq_len"] == 10
+
+
+@pytest.mark.cache
+@hardware_test(
+    res={"cuda": "L4"},
+    num_cards=2,
+)
+def test_manager_extraction_no_connector(kv_config, common_constants):
+    """Test extraction when connector is unavailable (should still return IDs)."""
+    block_size = common_constants["block_size"]
+    req_id = common_constants["req_id"]
+
+    manager = OmniKVTransferManager(kv_config)
+    # Force connector to be None
+    manager._connector = None
+    manager.config.connector_config = None
+    finished_reqs = {req_id: {"block_ids": [1, 2], "seq_len": 10}}
+
+    processed = manager.handle_finished_requests_kv_transfer(
+        finished_reqs, kv_caches=[], block_size=block_size, cache_dtype="float32"
+    )
+
+    assert req_id in processed
diff --git 
a/tests/distributed/omni_connectors/test_omni_connector_configs.py b/tests/distributed/omni_connectors/test_omni_connector_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca73e1ec4a31cfe516cb18cd2784f49f27a47382
--- /dev/null
+++ b/tests/distributed/omni_connectors/test_omni_connector_configs.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from pathlib import Path
+
+import pytest
+
+# Use the new import path for initialization utilities
+from vllm_omni.distributed.omni_connectors.utils.initialization import load_omni_transfer_config
+
+
+def get_config_files():
+    """Helper to find config files."""
+    # This file lives at <repo_root>/tests/distributed/omni_connectors/, so the
+    # repository root is four `.parent` steps up from the resolved file path.
+    base_dir = Path(__file__).resolve().parent.parent.parent.parent
+    config_dir = base_dir / "vllm_omni" / "model_executor" / "stage_configs"
+
+    if not config_dir.exists():
+        return []
+
+    return list(config_dir.glob("qwen*.yaml"))
+
+
+# Collect files at module level for parametrization
+config_files = get_config_files()
+
+
+@pytest.mark.skipif(len(config_files) == 0, reason="No config files found or directory missing")
+@pytest.mark.parametrize("yaml_file", config_files, ids=lambda p: p.name)
+def test_load_qwen_yaml_configs(yaml_file):
+    """
+    Scan and test loading of all qwen*.yaml config files.
+    This ensures that existing stage configs are compatible with the OmniConnector system.
+    """
+    print(f"Testing config load: {yaml_file.name}")
+    try:
+        # Load the config; the default_shm_threshold value does not affect load
+        # correctness, so the default is used.
+        config = load_omni_transfer_config(yaml_file)
+
+        assert config is not None, "Config should not be None"
+
+        # Basic validation. Some configs may omit the 'runtime' or 'connectors'
+        # section and rely on auto-SHM, but loading should still succeed.
+        # If the config defines stages, connectors are expected to be populated
+        # (explicitly or via auto SHM); we cannot strictly assert
+        # len(config.connectors) > 0 because a single-stage pipeline may have 0 edges.
+
+        print(f" -> Successfully loaded. 
Connectors: {len(config.connectors)}") + + except Exception as e: + pytest.fail(f"Failed to load config {yaml_file.name}: {e}") diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/e2e/offline_inference/__init__.py b/tests/e2e/offline_inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/e2e/offline_inference/conftest.py b/tests/e2e/offline_inference/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..89170983767a8b9672fa69991ba76b882306da7c --- /dev/null +++ b/tests/e2e/offline_inference/conftest.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Pytest configuration and fixtures for vllm-omni tests. +""" + +from typing import Any + +import pytest +from vllm import TextPrompt +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory + +from tests.conftest import _run_post_test_cleanup, _run_pre_test_cleanup +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None +PromptImageInput = list[Any] | Any | None +PromptVideoInput = list[Any] | Any | None + + +class OmniRunner: + """ + Test runner for Omni models. + """ + + def __init__( + self, + model_name: str, + seed: int = 42, + stage_init_timeout: int = 300, + batch_timeout: int = 10, + init_timeout: int = 300, + shm_threshold_bytes: int = 65536, + log_stats: bool = False, + stage_configs_path: str | None = None, + **kwargs, + ) -> None: + """ + Initialize an OmniRunner for testing. + + Args: + model_name: The model name or path + seed: Random seed for reproducibility + stage_init_timeout: Timeout for initializing a single stage in seconds + batch_timeout: Timeout for batching in seconds + init_timeout: Timeout for initializing stages in seconds + shm_threshold_bytes: Threshold for using shared memory + log_stats: Enable detailed statistics logging + stage_configs_path: Optional path to YAML stage config file + **kwargs: Additional arguments passed to Omni + """ + cleanup_dist_env_and_memory() + _run_pre_test_cleanup(enable_force=True) + _run_post_test_cleanup(enable_force=True) + self.model_name = model_name + self.seed = seed + + self.omni = Omni( + model=model_name, + log_stats=log_stats, + stage_init_timeout=stage_init_timeout, + batch_timeout=batch_timeout, + init_timeout=init_timeout, + shm_threshold_bytes=shm_threshold_bytes, + stage_configs_path=stage_configs_path, + **kwargs, + ) + + def get_default_sampling_params_list(self) -> list[OmniSamplingParams]: + """ + Get a list of default sampling parameters for all stages. + + Returns: + List of SamplingParams with default decoding for each stage + """ + return [st.default_sampling_params for st in self.omni.stage_list] + + def get_omni_inputs( + self, + prompts: list[str] | str, + system_prompt: str | None = None, + audios: PromptAudioInput = None, + images: PromptImageInput = None, + videos: PromptVideoInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + modalities: list[str] | None = None, + ) -> list[TextPrompt]: + """ + Construct Omni input format from prompts and multimodal data. 
+ + Args: + prompts: Text prompt(s) - either a single string or list of strings + system_prompt: Optional system prompt (defaults to Qwen system prompt) + audios: Audio input(s) - tuple of (audio_array, sample_rate) or list of tuples + images: Image input(s) - PIL Image or list of PIL Images + videos: Video input(s) - numpy array or list of numpy arrays + mm_processor_kwargs: Optional processor kwargs (e.g., use_audio_in_video) + + Returns: + List of prompt dictionaries suitable for Omni.generate() + """ + if system_prompt is None: + system_prompt = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." + ) + + video_padding_token = "<|VIDEO|>" + image_padding_token = "<|IMAGE|>" + audio_padding_token = "<|AUDIO|>" + + if self.model_name == "Qwen/Qwen3-Omni-30B-A3B-Instruct": + video_padding_token = "<|video_pad|>" + image_padding_token = "<|image_pad|>" + audio_padding_token = "<|audio_pad|>" + + if isinstance(prompts, str): + prompts = [prompts] + + def _normalize_mm_input(mm_input, num_prompts): + if mm_input is None: + return [None] * num_prompts + if isinstance(mm_input, list): + if len(mm_input) != num_prompts: + raise ValueError( + f"Multimodal input list length ({len(mm_input)}) must match prompts length ({num_prompts})" + ) + return mm_input + return [mm_input] * num_prompts + + num_prompts = len(prompts) + audios_list = _normalize_mm_input(audios, num_prompts) + images_list = _normalize_mm_input(images, num_prompts) + videos_list = _normalize_mm_input(videos, num_prompts) + + omni_inputs = [] + for i, prompt_text in enumerate(prompts): + user_content = "" + multi_modal_data = {} + + audio = audios_list[i] + if audio is not None: + if isinstance(audio, list): + for _ in audio: + user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" + multi_modal_data["audio"] = audio + else: + user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>" + multi_modal_data["audio"] = audio + + image = images_list[i] + if image is not None: + if isinstance(image, list): + for _ in image: + user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" + multi_modal_data["image"] = image + else: + user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>" + multi_modal_data["image"] = image + + video = videos_list[i] + if video is not None: + if isinstance(video, list): + for _ in video: + user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" + multi_modal_data["video"] = video + else: + user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>" + multi_modal_data["video"] = video + + user_content += prompt_text + + full_prompt = ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + input_dict: TextPrompt = {"prompt": full_prompt} + if multi_modal_data: + input_dict["multi_modal_data"] = multi_modal_data + if modalities: + input_dict["modalities"] = modalities + if mm_processor_kwargs: + input_dict["mm_processor_kwargs"] = mm_processor_kwargs + + omni_inputs.append(input_dict) + + return omni_inputs + + def generate( + self, + prompts: list[TextPrompt], + sampling_params_list: list[OmniSamplingParams] | None = None, + ) -> list[OmniRequestOutput]: + """ + Generate outputs for the given prompts. 
+ + Args: + prompts: List of prompt dictionaries with 'prompt' and optionally + 'multi_modal_data' keys + sampling_params_list: List of sampling parameters for each stage. + If None, uses default parameters. + + Returns: + List of OmniRequestOutput objects from stages with final_output=True + """ + if sampling_params_list is None: + sampling_params_list = self.get_default_sampling_params_list() + + return self.omni.generate(prompts, sampling_params_list) + + def generate_multimodal( + self, + prompts: list[str] | str, + sampling_params_list: list[OmniSamplingParams] | None = None, + system_prompt: str | None = None, + audios: PromptAudioInput = None, + images: PromptImageInput = None, + videos: PromptVideoInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + modalities: list[str] | None = None, + ) -> list[OmniRequestOutput]: + """ + Convenience method to generate with multimodal inputs. + + Args: + prompts: Text prompt(s) + sampling_params_list: List of sampling parameters for each stage + system_prompt: Optional system prompt + audios: Audio input(s) + images: Image input(s) + videos: Video input(s) + mm_processor_kwargs: Optional processor kwargs + + Returns: + List of OmniRequestOutput objects from stages with final_output=True + """ + omni_inputs = self.get_omni_inputs( + prompts=prompts, + system_prompt=system_prompt, + audios=audios, + images=images, + videos=videos, + mm_processor_kwargs=mm_processor_kwargs, + modalities=modalities, + ) + return self.generate(omni_inputs, sampling_params_list) + + def generate_audio( + self, + prompts: list[str] | str, + sampling_params_list: list[OmniSamplingParams] | None = None, + system_prompt: str | None = None, + audios: PromptAudioInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + ) -> list[OmniRequestOutput]: + """ + Convenience method to generate with multimodal inputs. + Args: + prompts: Text prompt(s) + sampling_params_list: List of sampling parameters for each stage + system_prompt: Optional system prompt + audios: Audio input(s) + mm_processor_kwargs: Optional processor kwargs + Returns: + List of OmniRequestOutput objects from stages with final_output=True + """ + omni_inputs = self.get_omni_inputs( + prompts=prompts, + system_prompt=system_prompt, + audios=audios, + mm_processor_kwargs=mm_processor_kwargs, + ) + return self.generate(omni_inputs, sampling_params_list) + + def generate_video( + self, + prompts: list[str] | str, + sampling_params_list: list[OmniSamplingParams] | None = None, + system_prompt: str | None = None, + videos: PromptVideoInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + ) -> list[OmniRequestOutput]: + """ + Convenience method to generate with multimodal inputs. 
+ Args: + prompts: Text prompt(s) + sampling_params_list: List of sampling parameters for each stage + system_prompt: Optional system prompt + videos: Video input(s) + mm_processor_kwargs: Optional processor kwargs + Returns: + List of OmniRequestOutput objects from stages with final_output=True + """ + omni_inputs = self.get_omni_inputs( + prompts=prompts, + system_prompt=system_prompt, + videos=videos, + mm_processor_kwargs=mm_processor_kwargs, + ) + return self.generate(omni_inputs, sampling_params_list) + + def generate_image( + self, + prompts: list[str] | str, + sampling_params_list: list[OmniSamplingParams] | None = None, + system_prompt: str | None = None, + images: PromptImageInput = None, + mm_processor_kwargs: dict[str, Any] | None = None, + ) -> list[OmniRequestOutput]: + """ + Convenience method to generate with multimodal inputs. + Args: + prompts: Text prompt(s) + sampling_params_list: List of sampling parameters for each stage + system_prompt: Optional system prompt + images: Image input(s) + mm_processor_kwargs: Optional processor kwargs + Returns: + List of OmniRequestOutput objects from stages with final_output=True + """ + omni_inputs = self.get_omni_inputs( + prompts=prompts, + system_prompt=system_prompt, + images=images, + mm_processor_kwargs=mm_processor_kwargs, + ) + return self.generate(omni_inputs, sampling_params_list) + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - cleanup resources.""" + self.close() + del self.omni + cleanup_dist_env_and_memory() + _run_post_test_cleanup(enable_force=True) + + def close(self): + """Close and cleanup the Omni instance.""" + if hasattr(self.omni, "close"): + self.omni.close() + + +@pytest.fixture(scope="session") +def omni_runner(): + return OmniRunner diff --git a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c170dffd35f0fe5deab3c6e563017e61c807d79 --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml @@ -0,0 +1,85 @@ +# stage config for running BAGEL with Mooncake connector for CI e2e tests. +# This config is optimized for single GPU tests with Mooncake inter-stage communication. 
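+# Note: the ${MOONCAKE_HOST} / ${MOONCAKE_RPC_PORT} / ${MOONCAKE_HTTP_PORT} placeholders below are
+# substituted at test time by _load_mooncake_config() in tests/e2e/offline_inference/test_bagel_text2img.py,
+# which starts a local mooncake_master and writes a temporary copy of this file with concrete
+# host/port values before constructing Omni.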
+ +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: BagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.35 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: mp + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: true + repetition_penalty: 1.05 + output_connectors: + to_stage_1: mooncake_connector + - stage_id: 1 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.55 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: mp + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + input_connectors: + from_stage_0: mooncake_connector + +# Top-level runtime config with Mooncake connector +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + connectors: + mooncake_connector: + name: MooncakeConnector + extra: + host: "${MOONCAKE_HOST}" + metadata_server: "http://${MOONCAKE_HOST}:${MOONCAKE_HTTP_PORT}/metadata" + master: "${MOONCAKE_HOST}:${MOONCAKE_RPC_PORT}" + segment: 64000000 + localbuf: 64000000 + proto: tcp + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef605a3a31191a36856fe5c34f4909369ba969cd --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml @@ -0,0 +1,83 @@ +# stage config for running BAGEL with SharedMemory connector for CI e2e tests. +# This config is optimized for single GPU tests with SharedMemory inter-stage communication. 
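+# This file is consumed by test_bagel_text2img_shared_memory_connector in
+# tests/e2e/offline_inference/test_bagel_text2img.py, roughly as:
+#   Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=<path to this file>, stage_init_timeout=300)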
+ +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: BagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.35 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished #or special token generated + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.55 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration (optional) + # More connectors will be supported in the future. + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6b04258a4c96c297365495ec743a5c180112f22 --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml @@ -0,0 +1,104 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# This config is optimized for CI e2e tests. 
+stage_args: + - stage_id: 0 + runtime: + process: true # Run this stage in a separate process + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/tests/e2e/offline_inference/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/qwen2_5_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e093ec51b9947f31e510e11866625f131bfbe705 --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/qwen2_5_omni_ci.yaml @@ -0,0 +1,106 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090). +# This config is optimized for CI e2e tests. 
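+# Illustrative usage via the OmniRunner helper in tests/e2e/offline_inference/conftest.py, which
+# forwards stage_configs_path to Omni(...). The model name below is an assumption for illustration
+# only, not taken from the tests in this PR:
+#   OmniRunner("Qwen/Qwen2.5-Omni-7B", stage_configs_path="tests/e2e/offline_inference/stage_configs/qwen2_5_omni_ci.yaml")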
+stage_args: + - stage_id: 0 + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + max_num_batched_tokens: 4069 + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..477e6e59f2921ccd030ca6dbedfb79931a89a04b --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml @@ -0,0 +1,99 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. 
+stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 1 + load_format: dummy + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + load_format: dummy + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 100 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + load_format: dummy + async_scheduling: false + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 200 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/e2e/offline_inference/stage_configs/rocm/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/rocm/qwen2_5_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474df5e7968b7513c12515f4fa4fed5f67939d42 --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/rocm/qwen2_5_omni_ci.yaml @@ -0,0 +1,105 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090). +# This config is optimized for CI e2e tests. 
+stage_args: + - stage_id: 0 + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 896 + max_num_batched_tokens: 896 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + skip_mm_profiling: true + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/tests/e2e/offline_inference/stage_configs/rocm/qwen3_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/rocm/qwen3_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..477e6e59f2921ccd030ca6dbedfb79931a89a04b --- /dev/null +++ b/tests/e2e/offline_inference/stage_configs/rocm/qwen3_omni_ci.yaml @@ -0,0 +1,99 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. 
+stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 1 + load_format: dummy + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + load_format: dummy + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 100 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + load_format: dummy + async_scheduling: false + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 200 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py new file mode 100644 index 0000000000000000000000000000000000000000..9aef47a74c2f2f38fef13ce43bc9e9e5356e2661 --- /dev/null +++ b/tests/e2e/offline_inference/test_bagel_text2img.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +End-to-end test for Bagel text2img generation. + +This test validates that the Bagel model generates images that match +expected reference pixel values within a ±5 tolerance. 
+ +Equivalent to running: + python3 examples/offline_inference/bagel/end2end.py \ + --prompts "A futuristic city skyline at twilight, cyberpunk style" \ + --modality text2img --step 15 +""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +import signal +import socket +import subprocess +import tempfile +import time +from pathlib import Path +from typing import Any + +import pytest +from PIL import Image + +from tests.utils import hardware_test +from vllm_omni.entrypoints.omni import Omni + +# Reference pixel data extracted from the known-good output image +# Each entry contains (x, y) position and expected (R, G, B) values +# "Generated with seed=52, num_inference_steps=15, +# prompt='A futuristic city skyline at twilight, cyberpunk style'" +REFERENCE_PIXELS = [ + {"position": (100, 100), "rgb": (68, 107, 134)}, + {"position": (400, 50), "rgb": (95, 139, 166)}, + {"position": (700, 100), "rgb": (99, 122, 151)}, + {"position": (150, 400), "rgb": (111, 125, 153)}, + {"position": (512, 512), "rgb": (97, 107, 131)}, + {"position": (700, 400), "rgb": (48, 64, 98)}, + {"position": (100, 700), "rgb": (79, 63, 84)}, + {"position": (400, 700), "rgb": (40, 58, 79)}, + {"position": (700, 700), "rgb": (60, 75, 103)}, + {"position": (256, 256), "rgb": (97, 128, 156)}, +] + +# Maximum allowed difference per color channel +PIXEL_TOLERANCE = 5 + +# Default test prompt +DEFAULT_PROMPT = "<|im_start|>A futuristic city skyline at twilight, cyberpunk style<|im_end|>" + + +def _find_free_port() -> int: + """Find and return a free ephemeral port by binding to port 0.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list: + """Configure sampling parameters for Bagel text2img generation. + + Args: + omni: The Omni instance to get default params from. + max_tokens: Maximum tokens for the first stage. + num_inference_steps: Number of inference steps for the diffusion stage. + + Returns: + Configured sampling params list. + """ + params_list = omni.default_sampling_params_list + params_list[0].max_tokens = max_tokens # type: ignore + if len(params_list) > 1: + params_list[1].num_inference_steps = num_inference_steps # type: ignore + return params_list + + +def _extract_generated_image(omni_outputs: list) -> Image.Image | None: + """Extract the generated image from Omni outputs. + + Args: + omni_outputs: List of outputs from omni.generate(). + + Returns: + The first generated PIL Image, or None if no image found. + """ + for req_output in omni_outputs: + if images := getattr(req_output, "images", None): + return images[0] + if hasattr(req_output, "request_output") and req_output.request_output: + for stage_out in req_output.request_output: + if hasattr(stage_out, "images") and stage_out.images: + return stage_out.images[0] + return None + + +def _validate_pixels( + image: Image.Image, + reference_pixels: list[dict[str, Any]] = REFERENCE_PIXELS, + tolerance: int = PIXEL_TOLERANCE, +) -> None: + """Validate that image pixels match expected reference values. + + Args: + image: The PIL Image to validate. + reference_pixels: List of dicts with 'position' (x, y) and 'rgb' (R, G, B). + tolerance: Maximum allowed difference per color channel. + + Raises: + AssertionError: If any pixel differs beyond tolerance. 
+ """ + for ref in reference_pixels: + x, y = ref["position"] + expected = ref["rgb"] + actual = image.getpixel((x, y))[:3] + assert all(abs(a - e) <= tolerance for a, e in zip(actual, expected)), ( + f"Pixel mismatch at ({x}, {y}): expected {expected}, got {actual}" + ) + + +def _generate_bagel_image(omni: Omni, prompt: str = DEFAULT_PROMPT) -> Image.Image: + """Generate an image using Bagel model with configured parameters. + + Args: + omni: The Omni instance to use for generation. + prompt: The text prompt for image generation. + + Returns: + The generated PIL Image. + + Raises: + AssertionError: If no image is generated or size is incorrect. + """ + params_list = _configure_sampling_params(omni) + + omni_outputs = list( + omni.generate( + prompts=[{"prompt": prompt, "modalities": ["image"]}], + sampling_params_list=params_list, + ) + ) + + generated_image = _extract_generated_image(omni_outputs) + assert generated_image is not None, "No images generated" + assert generated_image.size == (1024, 1024), f"Expected 1024x1024, got {generated_image.size}" + + return generated_image + + +@pytest.mark.core_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}) +def test_bagel_text2img_shared_memory_connector(): + """Test Bagel text2img with shared memory connector.""" + config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml") + omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300) + + try: + generated_image = _generate_bagel_image(omni) + _validate_pixels(generated_image) + finally: + omni.close() + + +def _wait_for_port(host: str, port: int, timeout: int = 30) -> bool: + """Wait for a port to become available. + + Args: + host: The host address. + port: The port number. + timeout: Maximum seconds to wait. + + Returns: + True if port becomes available, False otherwise. + """ + for _ in range(timeout): + try: + with socket.create_connection((host, port), timeout=1): + return True + except (TimeoutError, ConnectionRefusedError): + time.sleep(1) + return False + + +def _cleanup_mooncake_processes(timeout_secs: int = 5) -> None: + """Clean up any existing mooncake_master processes. + + Args: + timeout_secs: Maximum seconds to wait for graceful termination. + """ + subprocess.run( + ["pkill", "-f", "mooncake_master"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + start_time = time.time() + while time.time() - start_time < timeout_secs: + result = subprocess.run( + ["pgrep", "-f", "mooncake_master"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if result.returncode != 0: + break + time.sleep(0.5) + else: + subprocess.run( + ["pkill", "-9", "-f", "mooncake_master"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + time.sleep(1) + + +def _load_mooncake_config(host: str, rpc_port: int, http_port: int) -> str: + """Load Mooncake config from YAML and substitute placeholders. + + Args: + host: Mooncake host address. + rpc_port: RPC port for Mooncake master. + http_port: HTTP metadata server port. + + Returns: + Path to the temporary config file with substituted values. 
+ """ + config_path = str(Path(__file__).parent / "stage_configs" / "bagel_mooncake_ci.yaml") + with open(config_path) as f: + config_content = f.read() + + # Substitute placeholders + config_content = config_content.replace("${MOONCAKE_HOST}", host) + config_content = config_content.replace("${MOONCAKE_RPC_PORT}", str(rpc_port)) + config_content = config_content.replace("${MOONCAKE_HTTP_PORT}", str(http_port)) + + # Write to temp file + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + temp_file.write(config_content) + temp_file.close() + return temp_file.name + + +@pytest.mark.core_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100"}) +def test_bagel_text2img_mooncake_connector(): + """Test Bagel text2img with Mooncake connector for inter-stage communication.""" + MOONCAKE_HOST = "127.0.0.1" + MOONCAKE_RPC_PORT = _find_free_port() + MOONCAKE_HTTP_PORT = _find_free_port() + MOONCAKE_METRICS_PORT = _find_free_port() + + mooncake_master_proc = None + temp_config_file = None + omni = None + + try: + _cleanup_mooncake_processes() + + # Start mooncake_master + mooncake_master_proc = subprocess.Popen( + [ + "mooncake_master", + f"--rpc_port={MOONCAKE_RPC_PORT}", + "--enable_http_metadata_server=true", + "--http_metadata_server_host=0.0.0.0", + f"--http_metadata_server_port={MOONCAKE_HTTP_PORT}", + f"--metrics_port={MOONCAKE_METRICS_PORT}", + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + ) + + assert _wait_for_port(MOONCAKE_HOST, MOONCAKE_RPC_PORT), "mooncake_master failed to start" + + # Create temp config and initialize Omni + temp_config_file = _load_mooncake_config( + host=MOONCAKE_HOST, + rpc_port=MOONCAKE_RPC_PORT, + http_port=MOONCAKE_HTTP_PORT, + ) + + omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=temp_config_file, stage_init_timeout=300) + + generated_image = _generate_bagel_image(omni) + _validate_pixels(generated_image) + + finally: + if omni: + omni.close() + if temp_config_file: + try: + os.unlink(temp_config_file) + except OSError: + pass + if mooncake_master_proc: + try: + os.killpg(os.getpgid(mooncake_master_proc.pid), signal.SIGKILL) + except OSError: + pass diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..18d3988e04201c1f09b9b72895dd0e48cbc9df28 --- /dev/null +++ b/tests/e2e/offline_inference/test_cache_dit.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +System test for cache-dit backend. + +This test verifies that cache-dit acceleration works correctly with diffusion models. +It uses minimal settings to keep test time short for CI. 
+""" + +import os +import sys +from pathlib import Path + +import pytest +import torch + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni +from vllm_omni.outputs import OmniRequestOutput + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +# Use random weights model for testing +models = ["riverclouds/qwen_image_random"] + + +@pytest.mark.parametrize("model_name", models) +def test_cache_dit(model_name: str): + """Test cache-dit backend with diffusion model.""" + # Configure cache-dit with minimal settings for fast testing + cache_config = { + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 2, # Minimal warmup for fast test + "residual_diff_threshold": 0.24, + "max_continuous_cached_steps": 3, + } + m = None + try: + m = Omni( + model=model_name, + cache_backend="cache_dit", + cache_config=cache_config, + ) + + # Use minimal settings for fast testing + height = 256 + width = 256 + num_inference_steps = 4 # Minimal steps for fast test + + outputs = m.generate( + "a photo of a cat sitting on a laptop keyboard", + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, # Single output for speed + ), + ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + + # Verify generation succeeded + assert images is not None + assert len(images) == 1 + # Check image size + assert images[0].width == width + assert images[0].height == height + except Exception as e: + print(f"Test failed with error: {e}") + raise + finally: + if m is not None and hasattr(m, "close"): + m.close() diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..35e106df81a5883a52b439415d74e0009036fa14 --- /dev/null +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -0,0 +1,61 @@ +import sys +from pathlib import Path + +import pytest +import torch +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory + +from tests.utils import GPUMemoryMonitor +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +models = ["riverclouds/qwen_image_random"] + + +def inference(model_name: str, offload: bool = True): + current_omni_platform.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + m = Omni(model=model_name, enable_cpu_offload=offload) + torch.cuda.reset_peak_memory_stats(device=device_index) + height = 256 + width = 256 + 
+ m.generate( + "a photo of a cat sitting on a laptop keyboard", + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=9, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + ), + ) + peak = monitor.peak_used_mb + monitor.stop() + + return peak + + +@pytest.mark.skipif(current_omni_platform.is_npu() or current_omni_platform.is_rocm(), reason="Hardware not supported") +@pytest.mark.parametrize("model_name", models) +def test_cpu_offload_diffusion_model(model_name: str): + try: + no_offload_peak_memory = inference(model_name, offload=False) + cleanup_dist_env_and_memory() + offload_peak_memory = inference(model_name, offload=True) + except Exception: + pytest.fail("Inference failed") + print(f"Offload peak memory: {offload_peak_memory} MB") + print(f"No offload peak memory: {no_offload_peak_memory} MB") + assert offload_peak_memory + 2500 < no_offload_peak_memory, ( + f"Offload peak memory {offload_peak_memory} MB should be at least 2500 MB below the no-offload peak memory {no_offload_peak_memory} MB" + ) diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..87a9e0a9e5f5fba020360562db6c1ac563773427 --- /dev/null +++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py @@ -0,0 +1,110 @@ +import sys +from pathlib import Path + +import pytest +import torch +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory + +from tests.utils import GPUMemoryMonitor +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +# Models to test, mapped to the expected memory saving in MB +MODELS_SAVED_MEMORY_MB = {"riverclouds/qwen_image_random": 4500} + + +def run_inference( + model_name: str, + layerwise_offload: bool = False, + num_gpu_layers: int = 1, + num_inference_steps: int = 3, +) -> float: + # For now this is only supported on CUDA GPUs, so torch.cuda operations are used here; + # the test below is skipped on NPU / ROCm platforms. + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + + m = Omni( + model=model_name, + enable_layerwise_offload=layerwise_offload, + layerwise_num_gpu_layers=num_gpu_layers, + boundary_ratio=0.875, + flow_shift=5.0, + ) + + torch.cuda.reset_peak_memory_stats(device=device_index) + + # Refer to tests/e2e/offline_inference/test_t2v_model.py + # Use minimal settings for testing + height = 480 + width = 640 + num_frames = 5 + + m.generate( + "A cat sitting on a table", + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=torch.Generator("cuda").manual_seed(42), + guidance_scale=1.0, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + ), + ) + + peak = monitor.peak_used_mb + monitor.stop() + + return peak + + +@pytest.mark.skipif(current_omni_platform.is_npu() or current_omni_platform.is_rocm(), reason="Hardware not supported") +@pytest.mark.parametrize("model_name", MODELS_SAVED_MEMORY_MB.keys()) +def test_layerwise_offload_diffusion_model(model_name: str): + """Test that layerwise offloading reduces GPU memory usage. 
+ + This test verifies that layerwise offloading significantly reduces peak + GPU memory usage compared to loading the entire model on GPU. The layerwise + offloader keeps only a single transformer block on GPU at a time, with + prefetching for compute-memory overlap. + """ + try: + # Run without layerwise offloading (baseline) + no_offload_peak_memory = run_inference(model_name, layerwise_offload=False) + cleanup_dist_env_and_memory() + + # Run with layerwise offloading (1 layer on device) + layerwise_offload_peak_memory = run_inference(model_name, layerwise_offload=True, num_gpu_layers=1) + cleanup_dist_env_and_memory() + + # Run with 2 layers on device + layerwise_offload_two_layers_peak = run_inference(model_name, layerwise_offload=True, num_gpu_layers=2) + except Exception: + pytest.fail("Inference failed") + + print(f"Layerwise offload peak memory (1 GPU layer): {layerwise_offload_peak_memory} MB") + print(f"Layerwise offload peak memory (2 GPU layers): {layerwise_offload_two_layers_peak} MB") + print(f"No offload peak memory: {no_offload_peak_memory} MB") + + # Verify that layerwise offloading significantly reduces memory usage + # Passes only if the actual savings exceed the expected savings + assert layerwise_offload_peak_memory + MODELS_SAVED_MEMORY_MB[model_name] < no_offload_peak_memory, ( + f"Layerwise offload peak memory {layerwise_offload_peak_memory} MB " + f"should be significantly less than no offload peak memory {no_offload_peak_memory} MB" + ) + + # Verify that 2 GPU layers use more memory than 1 GPU layer + # (only the ordering is asserted; the size of the increase is not bounded here) + assert layerwise_offload_peak_memory < layerwise_offload_two_layers_peak, ( + f"1 GPU layer peak {layerwise_offload_peak_memory} MB should be < " + f"2 GPU layers peak {layerwise_offload_two_layers_peak} MB" + ) diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2465bc6071b6570c84517f18819f851ca0e58cc7 --- /dev/null +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -0,0 +1,138 @@ +import json +import os +import sys +from pathlib import Path + +import pytest +import torch +from safetensors.torch import save_file + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + + +# This test is specific to Z-Image LoRA behavior. Keep it focused on a single +# model to reduce runtime and avoid extra downloads. 
+models = ["Tongyi-MAI/Z-Image-Turbo"] + + +@pytest.mark.parametrize("model_name", models) +def test_diffusion_model(model_name: str, tmp_path: Path): + def _extract_images(outputs: list[OmniRequestOutput]): + if not outputs: + raise ValueError("Empty outputs from Omni.generate()") + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + return req_out.images + + def _write_zimage_lora(adapter_dir: Path) -> str: + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default (see ZImageTransformer2DModel). + dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V). With tp=1 and n_kv_heads==n_heads in Z-Image, + # each slice is `dim`, so total out dim is `3 * dim`. + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + # Apply a visible delta to the Q slice only to keep the perturbation bounded. + lora_b[:dim, 0] = 0.1 + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + return str(adapter_dir) + + m = Omni(model=model_name) + try: + # high resolution may cause OOM on L4 + height = 256 + width = 256 + prompt = "a photo of a cat sitting on a laptop keyboard" + + outputs = m.generate( + prompt, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + ), + ) + images = _extract_images(outputs) + + assert len(images) == 1 + # check image size + assert images[0].width == width + assert images[0].height == height + + # Real LoRA E2E: generate again with a real on-disk PEFT adapter and + # verify that output changes. 
+ if model_name == "Tongyi-MAI/Z-Image-Turbo": + from vllm_omni.lora.request import LoRARequest + from vllm_omni.lora.utils import stable_lora_int_id + + lora_dir = _write_zimage_lora(tmp_path / "zimage_lora") + lora_request = LoRARequest( + lora_name="test", + lora_int_id=stable_lora_int_id(lora_dir), + lora_path=lora_dir, + ) + outputs_lora = m.generate( + prompt, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + lora_request=lora_request, + lora_scale=2.0, + ), + ) + images_lora = _extract_images(outputs_lora) + assert len(images_lora) == 1 + assert images_lora[0].width == width + assert images_lora[0].height == height + + import numpy as np + + diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean() + assert diff > 0.0 + finally: + m.close() diff --git a/tests/e2e/offline_inference/test_ovis_image.py b/tests/e2e/offline_inference/test_ovis_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f1bc73817d3a615c187bf944eea769f119473804 --- /dev/null +++ b/tests/e2e/offline_inference/test_ovis_image.py @@ -0,0 +1,290 @@ +""" +Tests for Ovis Image model pipeline. + +Strategy: +1. `mock_dependencies` fixture mocks heavy external components (VAE, Scheduler, TextEncoder) + to allow fast testing of the pipeline logic without downloading weights. + - Mocks are configured to return tensors on the correct device. + - Transformer is mocked dynamically to return random noise of correct shape. + +2. `test_real_transformer_init_and_forward` tests the actual `OvisImageTransformer2DModel` + initialization and forward pass with a small configuration to ensure code coverage + and correctness of the model definition itself, independent of the pipeline mocks. +""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig + +# OvisImageTransformer2DModel is likely too heavy to construct in these tests, so the +# ovis_pipeline fixture below mocks the transformer forward to return random noise instead. +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.models.ovis_image.pipeline_ovis_image import OvisImagePipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + +@pytest.fixture +def mock_dependencies(monkeypatch): + """ + Mock external dependencies to avoid loading real models. 
+ """ + device = get_local_device() + + # Mock Tokenizer + mock_tokenizer = MagicMock() + mock_tokenizer.return_value = MagicMock( + input_ids=torch.zeros((1, 50), dtype=torch.long, device=device), + attention_mask=torch.ones((1, 50), dtype=torch.long, device=device), + ) + mock_tokenizer.apply_chat_template.return_value = "dummy prompt" + mock_tokenizer.model_max_length = 1024 + + # Mock Text Encoder + mock_text_encoder = MagicMock() + mock_text_encoder.dtype = torch.float32 + # Output of text encoder must be on the same device as inputs (which are moved to execution_device) + mock_text_encoder.return_value.last_hidden_state = torch.randn(1, 50, 32, device=device) + + # Mock VAE + mock_vae = MagicMock() + mock_vae.config.block_out_channels = [128, 256, 512, 512] # Scale factor 8 + mock_vae.config.scale_factor_temporal = 1 + mock_vae.config.scale_factor_spatial = 8 + mock_vae.config.scaling_factor = 0.18215 + mock_vae.config.shift_factor = 0.0 + # Decode return value + mock_vae.decode.return_value = [torch.randn(1, 3, 128, 128, device=device)] + # Ensure .to() returns self so configuration persists + mock_vae.to.return_value = mock_vae + + # Mock Scheduler + mock_scheduler = MagicMock() + mock_scheduler.config = MagicMock() + # Timesteps on device to match latents during denoising loop interaction if needed + mock_scheduler.timesteps = torch.tensor([1.0, 0.5, 0.0], device=device) + mock_scheduler.set_timesteps.return_value = None + + # Make step return dynamic based on input sample shape + def mock_scheduler_step(model_output, timestep, sample, **kwargs): + # sample is the latents, should be preserved + return (torch.randn_like(sample),) + + mock_scheduler.step.side_effect = mock_scheduler_step + + module_path = "vllm_omni.diffusion.models.ovis_image.pipeline_ovis_image" + + monkeypatch.setattr(f"{module_path}.Qwen2TokenizerFast.from_pretrained", lambda *a, **k: mock_tokenizer) + monkeypatch.setattr(f"{module_path}.Qwen3Model.from_pretrained", lambda *a, **k: mock_text_encoder) + monkeypatch.setattr(f"{module_path}.AutoencoderKL.from_pretrained", lambda *a, **k: mock_vae) + monkeypatch.setattr( + f"{module_path}.FlowMatchEulerDiscreteScheduler.from_pretrained", lambda *a, **k: mock_scheduler + ) + + return { + "tokenizer": mock_tokenizer, + "text_encoder": mock_text_encoder, + "vae": mock_vae, + "scheduler": mock_scheduler, + "device": device, + } + + +@pytest.fixture +def ovis_pipeline(mock_dependencies, monkeypatch): + """ + Creates an OvisImagePipeline instance with mocked components. 
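+
+    The transformer is also replaced with a MagicMock whose forward returns noise
+    shaped like its input, so only the pipeline plumbing is exercised here.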
+ """ + # Create config + tf_config = TransformerConfig( + params={ + "in_channels": 4, + "out_channels": 4, + "sample_size": 32, + "patch_size": 2, + "num_attention_heads": 4, + "attention_head_dim": 8, + "num_layers": 1, + "caption_channels": 32, + } + ) + + od_config = OmniDiffusionConfig( + model="dummy-ovis", + tf_model_config=tf_config, + dtype=torch.float32, + num_gpus=1, + ) + + # Mock Transformer Layer separately to avoid full init + # We patch OvisImageTransformer2DModel class in the module + mock_transformer_cls = MagicMock() + mock_transformer_instance = MagicMock() + mock_transformer_instance.dtype = torch.float32 + mock_transformer_instance.in_channels = 16 # Must be 16 so num_channel_latents=4, packed=16 + # Forward return: noise prediction + + def mock_forward(hidden_states, *args, **kwargs): + # hidden_states shape: (B, SeqLen, Channels) + return (torch.randn_like(hidden_states),) + + mock_transformer_instance.forward.side_effect = mock_forward + # Also make the instance itself callable to mimic __call__ + mock_transformer_instance.side_effect = mock_forward + + mock_transformer_cls.return_value = mock_transformer_instance + + monkeypatch.setattr( + "vllm_omni.diffusion.models.ovis_image.pipeline_ovis_image.OvisImageTransformer2DModel", mock_transformer_cls + ) + + # Initialize pipeline + # We use a dummy model path check override + with patch("os.path.exists", return_value=True): + pipeline = OvisImagePipeline(od_config=od_config) + + return pipeline + + +def test_interface_compliance(ovis_pipeline): + """Verify methods required by vllm-omni framework.""" + assert hasattr(ovis_pipeline, "load_weights") + assert hasattr(ovis_pipeline, "scheduler") + assert hasattr(ovis_pipeline, "transformer") + assert hasattr(ovis_pipeline, "text_encoder") + # assert hasattr(ovis_pipeline, "vae") # Ovis uses VAE + + +def test_basic_generation(ovis_pipeline): + """Test the forward pass logic.""" + # Setup request + req = OmniDiffusionRequest( + prompts=["A photo of a cat"], + sampling_params=OmniDiffusionSamplingParams( + height=256, + width=256, + num_inference_steps=2, + guidance_scale=1.0, + ), + ) + + output = ovis_pipeline(req) + + assert output is not None + assert output.output is not None + # Output should be a tensor from mocked VAE decode [torch.randn(1, 3, 128, 128)] + assert isinstance(output.output, torch.Tensor) + assert output.output.shape == (1, 3, 128, 128) + + # Check that transformer was called + assert ovis_pipeline.transformer.call_count > 0 + + +def test_guidance_scale(ovis_pipeline): + """Test that classifier-free guidance path is taken when scale > 1.0.""" + req = OmniDiffusionRequest( + prompts=[ + { + "prompt": "A photo of a cat", + "negative_prompt": "bad quality", + } + ], + sampling_params=OmniDiffusionSamplingParams( + height=256, + width=256, + num_inference_steps=1, + guidance_scale=2.0, # Trigger CFG + ), + ) + + ovis_pipeline(req) + assert ovis_pipeline.transformer.call_count >= 2 + + +def test_resolution_check(ovis_pipeline): + """Test resolution divisible validation logic if present.""" + # Pass odd resolution + req = OmniDiffusionRequest( + prompts=["test"], + sampling_params=OmniDiffusionSamplingParams( + height=250, # Not divisible by 16 (8*2) + width=250, + ), + ) + + # Should warn but proceed (as per code I read earlier) or resize? 
+ # The code had `logger.warning(...)` + + output = ovis_pipeline(req) + assert output is not None + + +def test_real_transformer_init_and_forward(): + """Test the real OvisImageTransformer2DModel initialization and forward pass for coverage.""" + from unittest.mock import patch + + from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel + + device = get_local_device() + tf_config = TransformerConfig( + params={ + "patch_size": 2, + "in_channels": 16, + "out_channels": 16, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "axes_dims_rope": (4, 4, 4), + } + ) + + od_config = OmniDiffusionConfig(model="dummy-ovis", tf_model_config=tf_config, dtype=torch.bfloat16, num_gpus=1) + torch.set_default_dtype(torch.bfloat16) + + # Mock distributed state for QKVParallelLinear initialization + # We patch get_tp_group because get_tensor_model_parallel_rank calls it and asserts _TP is not None + mock_group = MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 1 + + with patch("vllm.distributed.parallel_state.get_tp_group", return_value=mock_group): + # Initialize real model + model = OvisImageTransformer2DModel( + od_config=od_config, + patch_size=1, + in_channels=16, + out_channels=16, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=32, + axes_dims_rope=(2, 2, 4), + ).to(device) + + # Create dummy inputs + B, Seq, C = 1, 16, 16 + hidden_states = torch.randn(B, Seq, C, device=device) + encoder_hidden_states = torch.randn(B, 10, 32, device=device) # joint_attention_dim=32 + timestep = torch.tensor([1], device=device) + img_ids = torch.zeros(Seq, 3, device=device) + txt_ids = torch.zeros(10, 3, device=device) + + # Run forward + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_ids=img_ids, + txt_ids=txt_ids, + return_dict=False, + ) + + assert output is not None + assert isinstance(output, tuple) + assert output[0].shape == hidden_states.shape diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..af9d793c1dcfbecdbb2e7f1f9c5ac457a81c231f --- /dev/null +++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E tests for Qwen2.5-Omni model with mixed modality inputs and audio output. 
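+
+Covers two paths: mixed audio/image/video inputs producing text plus audio output,
+and the same inputs restricted to text-only output via modalities=["text"].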
+""" + +from pathlib import Path + +import pytest +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.multimodal.image import convert_image_mode + +from vllm_omni.platforms import current_omni_platform + +from .conftest import OmniRunner +from .utils import create_new_process_for_each_test + +models = ["Qwen/Qwen2.5-Omni-3B"] + +# CI stage config optimized for 24GB GPU (L4/RTX3090) or NPU +if current_omni_platform.is_npu(): + stage_config = str(Path(__file__).parent / "stage_configs" / "npu" / "qwen2_5_omni_ci.yaml") +elif current_omni_platform.is_rocm(): + # ROCm stage config optimized for MI325 GPU + stage_config = str(Path(__file__).parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml") +else: + stage_config = str(Path(__file__).parent / "stage_configs" / "qwen2_5_omni_ci.yaml") + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models] + + +@pytest.mark.core_model +@pytest.mark.parametrize("test_config", test_params) +@create_new_process_for_each_test("spawn") +def test_mixed_modalities_to_audio(omni_runner: type[OmniRunner], test_config: tuple[str, str]) -> None: + """Test processing audio, image, and video together, generating audio output.""" + model, stage_config_path = test_config + with omni_runner(model, seed=42, stage_configs_path=stage_config_path) as runner: + # Prepare multimodal inputs + question = "What is recited in the audio? What is in this image? Describe the video briefly." + audio = AudioAsset("mary_had_lamb").audio_and_sample_rate + audio = (audio[0][: 16000 * 5], audio[1]) # Trim to first 5 seconds + image = convert_image_mode(ImageAsset("cherry_blossom").pil_image.resize((128, 128)), "RGB") + if not VLLM_USE_MODELSCOPE: + video = VideoAsset(name="baby_reading", num_frames=4).np_ndarrays + else: + # modelscope can't access raushan-testing-hf/videos-test, skip video input temporarily + video = None + + outputs = runner.generate_multimodal( + prompts=question, + audios=audio, + images=image, + videos=video, + ) + + # Find and verify text output (thinker stage) + text_output = None + output_count = 0 + for stage_output in outputs: + if stage_output.final_output_type == "text": + text_output = stage_output + output_count += 1 + break + assert output_count > 0 + + assert text_output is not None + assert len(text_output.request_output) > 0 + text_content = text_output.request_output[0].outputs[0].text + assert text_content is not None + assert len(text_content.strip()) > 0 + + # Find and verify audio output (code2wav stage) + audio_output = None + output_count = 0 + for stage_output in outputs: + if stage_output.final_output_type == "audio": + audio_output = stage_output + output_count += 1 + break + assert output_count > 0 + + assert audio_output is not None + assert len(audio_output.request_output) > 0 + + # Verify audio tensor exists and has content + audio_tensor = audio_output.request_output[0].outputs[0].multimodal_output["audio"] + assert audio_tensor is not None + assert audio_tensor.numel() > 0 + + +@pytest.mark.core_model +@pytest.mark.parametrize("test_config", test_params) +@create_new_process_for_each_test("spawn") +def test_mixed_modalities_to_text_only(omni_runner: type[OmniRunner], test_config: tuple[str, str]) -> None: + """Test processing audio, image, and video together, generating audio output.""" + model, stage_config_path = test_config + with 
omni_runner(model, seed=42, stage_configs_path=stage_config_path) as runner: + # Prepare multimodal inputs + question = "What is recited in the audio? What is in this image? Describe the video briefly." + audio = AudioAsset("mary_had_lamb").audio_and_sample_rate + audio = (audio[0][: 16000 * 5], audio[1]) # Trim to first 5 seconds + image = convert_image_mode(ImageAsset("cherry_blossom").pil_image.resize((128, 128)), "RGB") + video = VideoAsset(name="baby_reading", num_frames=4).np_ndarrays + modalities = ["text"] + + outputs = runner.generate_multimodal( + prompts=question, + audios=audio, + images=image, + videos=video, + modalities=modalities, + ) + + # Find and verify text output (thinker stage) + text_output = None + output_count = 0 + for stage_output in outputs: + assert stage_output.final_output_type != "audio" + if stage_output.final_output_type == "text": + text_output = stage_output + output_count += 1 + break + assert output_count > 0 + + assert text_output is not None + assert len(text_output.request_output) > 0 + text_content = text_output.request_output[0].outputs[0].text + assert text_content is not None + assert len(text_content.strip()) > 0 diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5c66fe34805cf9644298716e5e822f356f4b61 --- /dev/null +++ b/tests/e2e/offline_inference/test_qwen3_omni.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E offline tests for Omni model with video input and audio output. +""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +from pathlib import Path + +import pytest +from vllm.assets.video import VideoAsset + +from vllm_omni.platforms import current_omni_platform + +from .conftest import OmniRunner + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] + +# CI stage config for 2xH100-80G GPUs or AMD GPU MI325 +if current_omni_platform.is_rocm(): + # ROCm stage config optimized for MI325 GPU + stage_configs = [str(Path(__file__).parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")] +else: + stage_configs = [str(Path(__file__).parent / "stage_configs" / "qwen3_omni_ci.yaml")] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +@pytest.mark.parametrize("test_config", test_params) +def test_video_to_audio(omni_runner: type[OmniRunner], test_config) -> None: + """Test processing video, generating audio output.""" + model, stage_config_path = test_config + with omni_runner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner: + # Prepare inputs + question = "Describe the video briefly." 
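+        # A 4-frame clip of the shared vLLM "baby_reading" asset keeps video preprocessing cheap in CI.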
+ video = VideoAsset(name="baby_reading", num_frames=4).np_ndarrays + + outputs = runner.generate_multimodal( + prompts=question, + videos=video, + ) + + # Find and verify text output (thinker stage) + text_output = None + output_count = 0 + for stage_output in outputs: + if stage_output.final_output_type == "text": + text_output = stage_output + output_count += 1 + break + + assert output_count > 0 + assert text_output is not None + assert len(text_output.request_output) > 0 + text_content = text_output.request_output[0].outputs[0].text + assert text_content is not None + assert len(text_content.strip()) > 0 + + # Find and verify audio output (code2wav stage) + audio_output = None + output_count = 0 + for stage_output in outputs: + if stage_output.final_output_type == "audio": + audio_output = stage_output + output_count += 1 + break + + assert output_count > 0 + assert audio_output is not None + assert len(audio_output.request_output) > 0 + + # Verify audio tensor exists and has content + audio_tensor = audio_output.request_output[0].outputs[0].multimodal_output["audio"] + assert audio_tensor is not None + assert audio_tensor.numel() > 0 diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7bb561799bfa71a11d0f3ea7d671548158f63d --- /dev/null +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +System test for Sequence Parallel (SP) backends: Ulysses and Ring attention. + +Tests verify that SP inference produces correct outputs compared to baseline. +""" + +import gc +import os +import sys +import time +from pathlib import Path +from typing import NamedTuple + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from PIL import Image + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.platforms import current_omni_platform + +# Test configuration +MODELS = ["riverclouds/qwen_image_random"] +PROMPT = "a photo of a cat sitting on a laptop keyboard" +DEFAULT_HEIGHT = 256 +DEFAULT_WIDTH = 256 +DEFAULT_SEED = 42 +DEFAULT_STEPS = 4 +DIFF_MEAN_THRESHOLD = 2e-2 +DIFF_MAX_THRESHOLD = 2e-1 + + +class InferenceResult(NamedTuple): + """Result of an inference run.""" + + images: list[Image.Image] + elapsed_ms: float + + +def _cleanup_distributed(): + """Clean up distributed environment and GPU resources.""" + if dist.is_initialized(): + dist.destroy_process_group() + + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]: + os.environ.pop(key, None) + + gc.collect() + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + current_omni_platform.synchronize() + + time.sleep(5) + + +def _diff_metrics(a: Image.Image, b: Image.Image) -> tuple[float, float]: + """Return (mean_abs_diff, max_abs_diff) over RGB pixels in [0, 1].""" + ta = torch.from_numpy(np.asarray(a.convert("RGB"), dtype=np.float32) / 255.0) + tb = torch.from_numpy(np.asarray(b.convert("RGB"), dtype=np.float32) / 255.0) + assert ta.shape == tb.shape, f"Image shapes differ: {ta.shape} vs {tb.shape}" + abs_diff = torch.abs(ta 
- tb) + return abs_diff.mean().item(), abs_diff.max().item() + + +def _run_inference( + model_name: str, + dtype: torch.dtype, + attn_backend: str, + ulysses_degree: int = 1, + ring_degree: int = 1, + height: int = DEFAULT_HEIGHT, + width: int = DEFAULT_WIDTH, + seed: int = DEFAULT_SEED, + warmup: bool = True, +) -> InferenceResult: + """Run inference with specified configuration. + + Args: + warmup: If True, run one warmup iteration before the timed run. + """ + parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree) + omni = Omni( + model=model_name, + parallel_config=parallel_config, + dtype=dtype, + attention_backend=attn_backend, + ) + + try: + # Warmup run (not timed) + if warmup: + _ = omni.generate( + PROMPT, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=DEFAULT_STEPS, + guidance_scale=0.0, + generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000), + num_outputs_per_prompt=1, + ), + ) + + # Timed run + start = time.time() + outputs = omni.generate( + PROMPT, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=DEFAULT_STEPS, + guidance_scale=0.0, + generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed), + num_outputs_per_prompt=1, + ), + ) + elapsed_ms = (time.time() - start) * 1000 + + return InferenceResult( + images=outputs[0].request_output[0].images, + elapsed_ms=elapsed_ms, + ) + finally: + omni.close() + _cleanup_distributed() + + +# ============================================================================= +# Correctness & Performance Tests +# ============================================================================= + +# SP configurations: (ulysses_degree, ring_degree, height, width, warmup, is_perf_test) +# - warmup: whether to run warmup for this SP config +# - is_perf_test: whether this is a performance test (show speedup metrics) +SP_CONFIGS = [ + (2, 1, DEFAULT_HEIGHT, DEFAULT_WIDTH, True, True), # Ulysses-2 - performance test + (1, 2, DEFAULT_HEIGHT, DEFAULT_WIDTH, True, True), # Ring-2 - performance test + (2, 2, DEFAULT_HEIGHT, DEFAULT_WIDTH, False, False), # Hybrid - correctness only + (4, 1, 272, 272, False, False), # Ulysses-4 - shape and correctness +] + + +def _get_sp_mode(ulysses_degree: int, ring_degree: int) -> str: + """Get SP mode name for logging.""" + if ulysses_degree > 1 and ring_degree == 1: + return f"ulysses-{ulysses_degree}" + elif ring_degree > 1 and ulysses_degree == 1: + return f"ring-{ring_degree}" + else: + return f"hybrid-{ulysses_degree}x{ring_degree}" + + +@pytest.mark.parametrize("model_name", MODELS) +def test_sp_correctness(model_name: str): + """Test that SP inference produces correct outputs and measure performance. + + Runs baseline once per unique (height, width), then tests all SP configs. + + Note: Run with `pytest -v -s` to see detailed output. 
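+
+    Correctness is judged by mean/max absolute pixel difference against a
+    single-GPU baseline (DIFF_MEAN_THRESHOLD / DIFF_MAX_THRESHOLD); configs
+    marked as perf tests additionally report a baseline-vs-SP speedup.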
+ """ + device_count = current_omni_platform.get_device_count() + + # Cache baseline results by (height, width) + # Key: (height, width), Value: (result, warmup_used) + baseline_cache: dict[tuple[int, int], InferenceResult] = {} + + # Collect results for summary + results: list[dict] = [] + + print("\n" + "=" * 70) + print(f"Sequence Parallel Test - Model: {model_name}") + print(f"Available GPUs: {device_count}") + print("=" * 70) + + for ulysses_degree, ring_degree, height, width, sp_warmup, is_perf_test in SP_CONFIGS: + sp_size = ulysses_degree * ring_degree + sp_mode = _get_sp_mode(ulysses_degree, ring_degree) + + if device_count < sp_size: + print(f"\n[{sp_mode}] SKIPPED (requires {sp_size} GPUs)") + continue + + # Determine baseline warmup: only for default size (performance tests) + cache_key = (height, width) + baseline_warmup = height == DEFAULT_HEIGHT and width == DEFAULT_WIDTH + + # Get or compute baseline for this (height, width) + if cache_key not in baseline_cache: + print(f"\n--- Running baseline {height}x{width} (warmup={baseline_warmup}) ---") + baseline = _run_inference( + model_name, + torch.bfloat16, + "sdpa", + height=height, + width=width, + warmup=baseline_warmup, + ) + assert len(baseline.images) == 1 + baseline_cache[cache_key] = baseline + print(f"[baseline] {height}x{width}: {baseline.elapsed_ms:.0f}ms") + else: + baseline = baseline_cache[cache_key] + + # Run SP + print(f"\n--- Running {sp_mode} (warmup={sp_warmup}) ---") + sp_result = _run_inference( + model_name, + torch.bfloat16, + "sdpa", + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + height=height, + width=width, + warmup=sp_warmup, + ) + assert len(sp_result.images) == 1 + + # Compare outputs (correctness) + mean_diff, max_diff = _diff_metrics(baseline.images[0], sp_result.images[0]) + + # Build result entry + result = { + "mode": sp_mode, + "sp_size": sp_size, + "height": height, + "width": width, + "baseline_ms": baseline.elapsed_ms, + "sp_ms": sp_result.elapsed_ms, + "mean_diff": mean_diff, + "max_diff": max_diff, + "is_perf_test": is_perf_test, + } + results.append(result) + + # Output based on test type + if is_perf_test: + speedup = baseline.elapsed_ms / sp_result.elapsed_ms if sp_result.elapsed_ms > 0 else 0 + result["speedup"] = speedup + print( + f"[{sp_mode}] {sp_size} GPUs | " + f"baseline: {baseline.elapsed_ms:.0f}ms, sp: {sp_result.elapsed_ms:.0f}ms, " + f"speedup: {speedup:.2f}x" + ) + else: + print(f"[{sp_mode}] {sp_size} GPUs | sp: {sp_result.elapsed_ms:.0f}ms (correctness only)") + + print(f"[{sp_mode}] diff: mean={mean_diff:.6e}, max={max_diff:.6e}") + + # Assert correctness + assert mean_diff <= DIFF_MEAN_THRESHOLD and max_diff <= DIFF_MAX_THRESHOLD, ( + f"[{sp_mode}] SP output differs from baseline: mean={mean_diff:.6e}, max={max_diff:.6e}" + ) + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"{'Mode':<15} {'GPUs':<6} {'Size':<10} {'Baseline':<12} {'SP':<12} {'Speedup':<10} {'Status'}") + print("-" * 70) + for r in results: + speedup_str = f"{r['speedup']:.2f}x" if r.get("speedup") else "N/A" + baseline_str = f"{r['baseline_ms']:.0f}ms" if r["is_perf_test"] else "N/A" + status = "PASS" if r["mean_diff"] <= DIFF_MEAN_THRESHOLD else "FAIL" + print( + f"{r['mode']:<15} {r['sp_size']:<6} {r['height']}x{r['width']:<5} " + f"{baseline_str:<12} {r['sp_ms']:.0f}ms{'':<7} {speedup_str:<10} {status}" + ) + print("=" * 70) diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py 
new file mode 100644 index 0000000000000000000000000000000000000000..df2ca5e4283ca13a1dc0cb591abe27a23676dbb8 --- /dev/null +++ b/tests/e2e/offline_inference/test_stable_audio_model.py @@ -0,0 +1,67 @@ +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +# Use random weights model for CI testing (small, no authentication required) +models = ["linyueqian/stable_audio_random"] + + +@pytest.mark.parametrize("model_name", models) +def test_stable_audio_model(model_name: str): + m = Omni(model=model_name) + + # Use minimal settings for testing + # Generate a short 2-second audio clip with minimal inference steps + audio_start_in_s = 0.0 + audio_end_in_s = 2.0 # Short duration for fast testing + sample_rate = 44100 # Stable Audio uses 44100 Hz + + outputs = m.generate( + prompts={ + "prompt": "The sound of a dog barking", + "negative_prompt": "Low quality.", + }, + sampling_params_list=OmniDiffusionSamplingParams( + num_inference_steps=4, # Minimal steps for speed + guidance_scale=7.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + extra_args={ + "audio_start_in_s": audio_start_in_s, + "audio_end_in_s": audio_end_in_s, + }, + ), + ) + + # Extract audio from OmniRequestOutput + assert outputs is not None + first_output = outputs[0] + assert first_output.final_output_type == "image" + assert hasattr(first_output, "request_output") and first_output.request_output + + req_out = first_output.request_output[0] + assert isinstance(req_out, OmniRequestOutput) + assert req_out.final_output_type == "audio" + assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output + audio = req_out.multimodal_output.get("audio") + assert isinstance(audio, np.ndarray) + # audio shape: (batch, channels, samples) + # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples + assert audio.ndim == 3 + assert audio.shape[0] == 1 # batch size + assert audio.shape[1] == 2 # stereo channels + expected_samples = int((audio_end_in_s - audio_start_in_s) * sample_rate) + assert audio.shape[2] == expected_samples # 88200 samples for 2 seconds diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e7351f1573ce72772f7e5360a0abd2f9f4b30c23 --- /dev/null +++ b/tests/e2e/offline_inference/test_t2i_model.py @@ -0,0 +1,76 @@ +import os +import sys +from pathlib import Path + +import pytest +import torch + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + + +models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"] + +# Modelscope can't find riverclouds/qwen_image_random +# TODO: When NPU support is ready, remove this branch. 
+if current_omni_platform.is_npu(): + models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"] +elif current_omni_platform.is_rocm(): + # TODO: When ROCm support is ready, remove this branch. + # vLLM V0.11.0 has issues running riverclouds/qwen_image_random + # on ROCm + models = ["Tongyi-MAI/Z-Image-Turbo"] + + +@pytest.mark.parametrize("model_name", models) +def test_diffusion_model(model_name: str): + m = None + try: + m = Omni(model=model_name) + # high resolution may cause OOM on L4 + height = 256 + width = 256 + outputs = m.generate( + "a photo of a cat sitting on a laptop keyboard", + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=2, + ), + ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + + assert len(images) == 2 + # check image size + assert images[0].width == width + assert images[0].height == height + images[0].save("image_output.png") + except Exception as e: + print(f"Test failed with error: {e}") + raise + finally: + if m is not None and hasattr(m, "close"): + m.close() diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a378291acdc0d5192d19f9efff8e29283a253a89 --- /dev/null +++ b/tests/e2e/offline_inference/test_t2v_model.py @@ -0,0 +1,64 @@ +import os +import sys +from pathlib import Path + +import pytest +import torch + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni +from vllm_omni.outputs import OmniRequestOutput + +# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"] + + +@pytest.mark.parametrize("model_name", models) +def test_video_diffusion_model(model_name: str): + m = Omni( + model=model_name, + boundary_ratio=0.875, + flow_shift=5.0, + ) + # Use minimal settings for testing + # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1 + # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ... 
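+    # Worked example (assuming the usual causal-VAE packing): 5 frames map to
+    # (5 - 1) / 4 + 1 = 2 latent frames, the smallest valid multi-frame setting.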
+ height = 480 + width = 640 + num_frames = 5 + outputs = m.generate( + prompts="A cat sitting on a table", + sampling_params_list=OmniDiffusionSamplingParams( + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=2, + guidance_scale=1.0, + generator=torch.Generator("cuda").manual_seed(42), + ), + ) + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + frames = req_out.images[0] + + assert frames is not None + assert hasattr(frames, "shape") + # frames shape: (batch, num_frames, height, width, channels) + assert frames.shape[1] == num_frames + assert frames.shape[2] == height + assert frames.shape[3] == width diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6261388192b7f4e8e0db9603bf926532047d5e --- /dev/null +++ b/tests/e2e/offline_inference/test_teacache.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +System test for TeaCache backend. + +This test verifies that TeaCache acceleration works correctly with diffusion models. +It uses minimal settings to keep test time short for CI. +""" + +import os +import sys +from pathlib import Path + +import pytest +import torch + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni +from vllm_omni.outputs import OmniRequestOutput + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +# Use random weights model for testing +models = ["riverclouds/qwen_image_random"] + + +@pytest.mark.parametrize("model_name", models) +def test_teacache(model_name: str): + """Test TeaCache backend with diffusion model.""" + # Configure TeaCache with default settings for fast testing + cache_config = { + "rel_l1_thresh": 0.2, # Default threshold + } + m = None + try: + m = Omni( + model=model_name, + cache_backend="tea_cache", + cache_config=cache_config, + ) + + # Use minimal settings for fast testing + height = 256 + width = 256 + num_inference_steps = 4 # Minimal steps for fast test + + outputs = m.generate( + "a photo of a cat sitting on a laptop keyboard", + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, # Single output for speed + ), + ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + + # Verify generation succeeded + assert images is not None + 
assert len(images) == 1 + # Check image size + assert images[0].width == width + assert images[0].height == height + except Exception as e: + print(f"Test failed with error: {e}") + raise + finally: + if m is not None and hasattr(m, "close"): + m.close() diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3b97ec39fbaba7c79205c036c4beeef4661bf6 --- /dev/null +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import sys +import time +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from vllm.distributed.parallel_state import cleanup_dist_env_and_memory + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from tests.utils import GPUMemoryMonitor +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +PROMPT = "a photo of a cat sitting on a laptop keyboard" + + +def _get_zimage_model() -> str: + # Allow overriding the model for local/offline environments. + # Can be either a HuggingFace repo id or a local path. + return os.environ.get("VLLM_TEST_ZIMAGE_MODEL", "Tongyi-MAI/Z-Image-Turbo") + + +def _pil_to_float_rgb_tensor(img: Image.Image) -> torch.Tensor: + """Convert PIL image to float32 RGB tensor in [0, 1] with shape [H, W, 3].""" + arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0 + return torch.from_numpy(arr) + + +def _diff_metrics(a: Image.Image, b: Image.Image) -> tuple[float, float]: + """Return (mean_abs_diff, max_abs_diff) over RGB pixels in [0, 1].""" + ta = _pil_to_float_rgb_tensor(a) + tb = _pil_to_float_rgb_tensor(b) + assert ta.shape == tb.shape, f"Image shapes differ: {ta.shape} vs {tb.shape}" + abs_diff = torch.abs(ta - tb) + return abs_diff.mean().item(), abs_diff.max().item() + + +def _extract_single_image(outputs) -> Image.Image: + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out.images + if images is None or len(images) != 1: + raise ValueError(f"Expected 1 image, got {0 if images is None else len(images)}") + return images[0] + + +def _run_zimage_generate( + *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int +) -> tuple[Image.Image, float, float]: + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + + m = Omni( + model=_get_zimage_model(), + parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), + ) + try: + # NOTE: Omni closes itself when a generate() call 
is exhausted. + # To avoid measuring teardown time (process shutdown, memory cleanup), + # we measure the latency to produce *subsequent* outputs within a single + # generator run. + # + # This also serves as a warmup: the first output may include extra + # compilation/caching overhead, while later outputs are closer to + # steady-state inference. + num_requests = 4 # 1 warmup + 3 timed + gen = m.generate( + [PROMPT] * num_requests, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + seed=seed, + num_outputs_per_prompt=1, + ), + py_generator=True, + ) + + warmup_output = next(gen) + t_prev = time.perf_counter() + per_request_times_s: list[float] = [] + last_output = warmup_output + for _ in range(num_requests - 1): + last_output = next(gen) + t_now = time.perf_counter() + per_request_times_s.append(t_now - t_prev) + t_prev = t_now + + # Ensure the generator is fully consumed so it can clean up. + for _ in gen: + pass + + median_time_s = float(np.median(per_request_times_s)) + + peak_memory_mb = monitor.peak_used_mb + + return _extract_single_image([last_output]), median_time_s, peak_memory_mb + finally: + monitor.stop() + cleanup_dist_env_and_memory() + + +@pytest.mark.integration +def test_zimage_tensor_parallel_tp2(tmp_path: Path): + if current_omni_platform.is_npu() or current_omni_platform.is_rocm(): + pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.") + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.") + + height = 512 + width = 512 + num_inference_steps = 2 + seed = 42 + + tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate( + tp_size=1, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + ) + tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( + tp_size=2, + height=height, + width=width, + num_inference_steps=num_inference_steps, + seed=seed, + ) + + tp1_path = tmp_path / "zimage_tp1.png" + tp2_path = tmp_path / "zimage_tp2.png" + tp1_img.save(tp1_path) + tp2_img.save(tp2_path) + + assert tp1_img.width == width and tp1_img.height == height + assert tp2_img.width == width and tp2_img.height == height + + mean_abs_diff, max_abs_diff = _diff_metrics(tp1_img, tp2_img) + mean_threshold = 3e-2 + max_threshold = 5e-1 + print( + "Z-Image TP image diff stats (TP=1 vs TP=2): " + f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; " + f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; " + f"tp1_img={tp1_path}, tp2_img={tp2_path}" + ) + assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, ( + f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} " + f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e})" + ) + + print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}") + assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})" + + print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}") + assert tp2_peak_mem < tp1_peak_mem, ( + f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})" + ) diff --git a/tests/e2e/offline_inference/utils.py b/tests/e2e/offline_inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3113599a30589ea97382682e3271d85799bfbdc5 --- 
/dev/null +++ b/tests/e2e/offline_inference/utils.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import functools +import os +import signal +import subprocess +import sys +import tempfile +from collections.abc import Callable +from contextlib import ExitStack, suppress +from pathlib import Path +from typing import Any, Literal + +import cloudpickle +from typing_extensions import ParamSpec +from vllm.platforms import current_platform + +VLLM_PATH = Path(__file__).parent.parent.parent +"""Path to root of the vLLM repository.""" + + +_P = ParamSpec("_P") + + +def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to fork a new process for each test function. + See https://github.com/vllm-project/vllm/issues/7053 for more details. + """ + + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Make the process the leader of its own process group + # to avoid sending SIGTERM to the parent process + os.setpgrp() + from _pytest.outcomes import Skipped + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. + with ( + tempfile.NamedTemporaryFile( + delete=False, + mode="w+b", + prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", + suffix=".exc", + ) as exc_file, + ExitStack() as delete_after, + ): + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. + delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {"pickled_exception": e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. + exc_to_serialize = { + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": tb_string, + } + try: + with open(exc_file_path, "wb") as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. + print(tb_string) + os._exit(1) + else: + os._exit(0) + else: + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with ( + contextlib.suppress(Exception), + open(exc_file_path, "rb") as f, + ): + exc_info = cloudpickle.load(f) + + original_exception = exc_info.get("pickled_exception") + if original_exception is not None and isinstance(original_exception, Exception): + # Re-raise the actual exception object if it was + # successfully pickled. 
+ raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})" + ) from None + + return wrapper + + +def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function.""" + + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Check if we're already in a subprocess + if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": + # If we are, just run the function directly + return f(*args, **kwargs) + + import torch.multiprocessing as mp + + with suppress(RuntimeError): + mp.set_start_method("spawn") + + # Get the module + module_name = f.__module__ + + # Create a process with environment variable set + env = os.environ.copy() + env["RUNNING_IN_SUBPROCESS"] = "1" + + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "new_process.tmp") + + # `cloudpickle` allows pickling complex functions directly + input_bytes = cloudpickle.dumps((f, output_filepath)) + + repo_root = str(VLLM_PATH.resolve()) + + env = dict(env or os.environ) + env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") + + cmd = [sys.executable, "-m", f"{module_name}"] + + returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e + + return wrapper + + +def create_new_process_for_each_test( + method: Literal["spawn", "fork"] | None = None, +) -> Callable[[Callable[_P, None]], Callable[_P, None]]: + """Creates a decorator that runs each test function in a new process. + + Args: + method: The process creation method. Can be either "spawn" or "fork". + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. + + Returns: + A decorator to run test functions in separate processes. + """ + if method is None: + # TODO: Find out why spawn is not working correctly on ROCm + # The test content will not run and tests passed immediately. + # For now, using `fork` for ROCm as it can run with `fork` + # and tests are running correctly. 
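+        # Effective default below: spawn only on XPU; all other platforms
+        # (including ROCm) fall back to fork.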
+ use_spawn = current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" + + assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" + + if method == "fork": + return fork_new_process_for_each_test + + return spawn_new_process_for_each_test diff --git a/tests/e2e/online_serving/__init__.py b/tests/e2e/online_serving/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f0161edd2d6086fd38a8f3ab373c8076224017e --- /dev/null +++ b/tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml @@ -0,0 +1,103 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + load_format: dummy + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 5 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + load_format: dummy + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 100 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + load_format: dummy 
+ engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 200 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/e2e/online_serving/stage_configs/qwen3_omni_thinker_ci.yaml b/tests/e2e/online_serving/stage_configs/qwen3_omni_thinker_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89d03966d0648743ccbbcfc1f425625301680a58 --- /dev/null +++ b/tests/e2e/online_serving/stage_configs/qwen3_omni_thinker_ci.yaml @@ -0,0 +1,31 @@ +# The following config has been verified on 2x H100-80G GPUs. +stage_args: + - stage_id: 0 + runtime: + devices: "0,1" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 2 + load_format: dummy + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 diff --git a/tests/e2e/online_serving/stage_configs/rocm/qwen3_omni_ci.yaml b/tests/e2e/online_serving/stage_configs/rocm/qwen3_omni_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59642a77b6fcc9337b82252f51e048f6e073c1fc --- /dev/null +++ b/tests/e2e/online_serving/stage_configs/rocm/qwen3_omni_ci.yaml @@ -0,0 +1,95 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) +# The following config has been verified on 2x H100-80G GPUs. 
+stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 1 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 5 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 1000 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + async_scheduling: false + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2000 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/e2e/online_serving/test_async_omni.py b/tests/e2e/online_serving/test_async_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..d90727e1bc527646c5e65af7012b998c00712190 --- /dev/null +++ b/tests/e2e/online_serving/test_async_omni.py @@ -0,0 +1,161 @@ +import asyncio +import os +import sys +from contextlib import ExitStack +from pathlib import Path + +import pytest +from vllm import SamplingParams +from vllm.inputs import PromptType + +from vllm_omni.entrypoints.async_omni import AsyncOmni, ClientRequestState + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +SEED = 42 + +stage_config = str(Path(__file__).parent / "stage_configs" / "qwen3_omni_thinker_ci.yaml") +model = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + +async def generate( + engine: AsyncOmni, + request_id: str, + prompt: PromptType, + max_tokens: int, +) -> tuple[int, str]: + # Ensure generate doesn't complete too fast for cancellation test. 
+ await asyncio.sleep(0.2) + thinker_sampling_params = SamplingParams( + temperature=0.4, # Deterministic + top_p=0.9, + top_k=1, + max_tokens=max_tokens, + repetition_penalty=1.05, + stop_token_ids=[151645], # Qwen EOS token <|im_end|> + seed=SEED, + ) + + sampling_params_list = [ + thinker_sampling_params, + ] + count = 0 + async for omni_output in engine.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["text"], + ): + stage_id = omni_output.stage_id + out = omni_output.request_output + if stage_id == 0: + num_tokens = sum(len(output.token_ids) for output in out.outputs) + count = num_tokens + + await asyncio.sleep(0.0) + + return count, request_id + + +@pytest.mark.asyncio +async def test_abort(): + with ExitStack() as after: + # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS. + engine = AsyncOmni( + model=model, + stage_configs_path=stage_config, + shm_threshold_bytes=sys.maxsize, + ) + after.callback(engine.shutdown) + + # Keep token counts modest to reduce flakiness on slow test hardware. + NUM_REQUESTS = 3 + NUM_EXPECTED_TOKENS = 64 + NUM_EXPECTED_TOKENS_LONG = 256 + REQUEST_IDS_TO_ABORT = [1] + + prompt = "Hello my name is Robert and " + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks: list[asyncio.Task] = [] + for idx, request_id in enumerate(request_ids): + max_tokens = NUM_EXPECTED_TOKENS_LONG if (idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS + tasks.append(asyncio.create_task(generate(engine, request_id, prompt, max_tokens))) + + # API server cancels requests when they disconnect. + # Explicitly abort in the engine to avoid orphaned requests hanging. + for idx in REQUEST_IDS_TO_ABORT: + tasks[idx].cancel() + await engine.abort(request_ids[idx]) + await asyncio.sleep(0.1) + + # Confirm the other requests are okay. + for idx, task in enumerate(tasks): + # Confirm that it was actually canceled. + if idx in REQUEST_IDS_TO_ABORT: + with pytest.raises((asyncio.CancelledError, GeneratorExit)): + await asyncio.wait_for(task, timeout=60) + else: + # Otherwise, make sure the request was not impacted. + num_generated_tokens, request_id = await asyncio.wait_for(task, timeout=180) + expected_tokens = NUM_EXPECTED_TOKENS + assert num_generated_tokens == expected_tokens, ( + f"{request_id} generated {num_generated_tokens} but expected {expected_tokens}" + ) + + # Confirm we can do another generation. + request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}" + task = asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS)) + num_generated_tokens, request_id = await task + assert num_generated_tokens == NUM_EXPECTED_TOKENS + await asyncio.sleep(5) + + +@pytest.mark.asyncio +async def test_build_and_log_summary(monkeypatch): + from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e + + RealCRS = ClientRequestState + capture_metrics = {} + + class MockCRS(RealCRS): + def __init__(self, request_id: str): + super().__init__(request_id) + capture_metrics[request_id] = self + + monkeypatch.setattr("vllm_omni.entrypoints.async_omni.ClientRequestState", MockCRS) + monkeypatch.setattr("vllm_omni.entrypoints.client_request_state.ClientRequestState", MockCRS) + + with ExitStack() as after: + # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS. 
+ engine = AsyncOmni( + model=model, + stage_configs_path=stage_config, + shm_threshold_bytes=sys.maxsize, + ) + after.callback(engine.shutdown) + prompt = "Hello my name is Robert and " + NUM_EXPECTED_TOKENS = 64 + NUM_REQUESTS = 3 + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks: list[asyncio.Task] = [] + for idx, request_id in enumerate(request_ids): + tasks.append(asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS))) + + # Confirm the requests are okay. + for idx, task in enumerate(tasks): + await task + output_modalities = ["text"] + final_stage_id_for_e2e = get_final_stage_id_for_e2e( + output_modalities, engine.output_modalities, engine.stage_list + ) + summary = capture_metrics[request_ids[idx]].metrics.build_and_log_summary(final_stage_id_for_e2e) + + # Check that total tokens matches sum of stage tokens. + assert summary["e2e_total_tokens"] == sum(stage["tokens"] for stage in summary["stages"]) + # Check that total time matches sum of stage times. + assert summary["e2e_total_time_ms"] >= sum(stage["total_time_ms"] for stage in summary["stages"]) diff --git a/tests/e2e/online_serving/test_image_gen_edit.py b/tests/e2e/online_serving/test_image_gen_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..8db0d50fbe465de55011506403a510287124beb9 --- /dev/null +++ b/tests/e2e/online_serving/test_image_gen_edit.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving test for Qwen-Image-Edit-2509 multi-image input. +""" + +import base64 +import os +import signal +import socket +import subprocess +import sys +import threading +import time +from io import BytesIO +from typing import Any + +import openai +import pytest +import requests +from PIL import Image +from vllm.assets.image import ImageAsset +from vllm.utils.network_utils import get_open_port + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +# Increase timeout for downloading assets from S3 (default 5s is too short for CI) +os.environ.setdefault("VLLM_IMAGE_FETCH_TIMEOUT", "60") + +models = ["Qwen/Qwen-Image-Edit-2509"] +test_params = models +t2i_models = ["Tongyi-MAI/Z-Image-Turbo"] + + +class OmniServer: + """Omniserver for vLLM-Omni tests.""" + + def __init__( + self, + model: str, + serve_args: list[str], + *, + env_dict: dict[str, str] | None = None, + ) -> None: + self.model = model + self.serve_args = serve_args + self.env_dict = env_dict + self.proc: subprocess.Popen | None = None + self.host = "127.0.0.1" + self.port = get_open_port() + + def _start_server(self) -> None: + """Start the vLLM-Omni server subprocess.""" + env = os.environ.copy() + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if self.env_dict is not None: + env.update(self.env_dict) + + cmd = [ + sys.executable, + "-m", + "vllm_omni.entrypoints.cli.main", + "serve", + self.model, + "--omni", + "--host", + self.host, + "--port", + str(self.port), + ] + self.serve_args + + print(f"Launching OmniServer with: {' '.join(cmd)}") + self.proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + start_new_session=True, + ) + + # Wait for server to be ready + max_wait = 600 # 10 minutes + start_time = time.time() + while time.time() - start_time < max_wait: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + result = 
sock.connect_ex((self.host, self.port)) + if result == 0: + print(f"Server ready on {self.host}:{self.port}") + return + except Exception: + pass + time.sleep(2) + + raise RuntimeError(f"Server failed to start within {max_wait} seconds") + + def __enter__(self): + self._start_server() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.proc: + try: + os.killpg(self.proc.pid, signal.SIGTERM) + except ProcessLookupError: + pass + + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + try: + os.killpg(self.proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass + self.proc.wait() + + +@pytest.fixture +def omni_server(request): + """Start vLLM-Omni server as a subprocess with actual model weights.""" + model = request.param + with OmniServer(model, ["--num-gpus", "1"]) as server: + yield server + + +@pytest.fixture +def client(omni_server): + """OpenAI client for the running vLLM-Omni server.""" + return openai.OpenAI( + base_url=f"http://{omni_server.host}:{omni_server.port}/v1", + api_key="EMPTY", + ) + + +@pytest.fixture(scope="session") +def base64_encoded_images() -> list[str]: + """Base64 encoded PNG images for testing.""" + images = [ + ImageAsset("cherry_blossom").pil_image.convert("RGB"), + ImageAsset("stop_sign").pil_image.convert("RGB"), + ] + encoded: list[str] = [] + for img in images: + with BytesIO() as buffer: + img.save(buffer, format="PNG") + encoded.append(base64.b64encode(buffer.getvalue()).decode("utf-8")) + return encoded + + +def dummy_messages_from_image_data( + image_data_urls: list[str], + content_text: str = "Combine these two images into one scene.", +): + """Create messages with image data URLs for OpenAI API.""" + content = [{"type": "text", "text": content_text}] + for image_url in image_data_urls: + content.append({"type": "image_url", "image_url": {"url": image_url}}) + return [{"role": "user", "content": content}] + + +def _extract_image_data_url(message_content) -> str: + assert isinstance(message_content, list) and len(message_content) >= 1 + content_part = message_content[0] + if isinstance(content_part, dict): + image_url = content_part.get("image_url", {}).get("url", "") + else: + image_url_obj = getattr(content_part, "image_url", None) + if isinstance(image_url_obj, dict): + image_url = image_url_obj.get("url", "") + else: + image_url = getattr(image_url_obj, "url", "") + assert isinstance(image_url, str) and image_url + return image_url + + +def _decode_data_url_to_image_bytes(data_url: str) -> bytes: + assert data_url.startswith("data:image") + _, b64_data = data_url.split(",", 1) + return base64.b64decode(b64_data) + + +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_i2i_multi_image_input_qwen_image_edit_2509( + omni_server, + base64_encoded_images: list[str], +) -> None: + """Test multi-image input editing via OpenAI API with concurrent requests.""" + image_data_urls = [f"data:image/png;base64,{img}" for img in base64_encoded_images] + messages = dummy_messages_from_image_data(image_data_urls) + + barrier = threading.Barrier(2) + results: list[tuple[int, int]] = [] + + def _call_chat(width: int, height: int) -> None: + client = openai.OpenAI( + base_url=f"http://{omni_server.host}:{omni_server.port}/v1", + api_key="EMPTY", + ) + barrier.wait() + chat_completion = client.chat.completions.create( + model=omni_server.model, + messages=messages, + extra_body={ + "height": height, + "width": width, + "num_inference_steps": 2, + "guidance_scale": 0.0, + "seed": 42, + }, + ) + + 
assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + assert choice.message.role == "assistant" + + image_data_url = _extract_image_data_url(choice.message.content) + image_bytes = _decode_data_url_to_image_bytes(image_data_url) + img = Image.open(BytesIO(image_bytes)) + img.load() + results.append(img.size) + + threads = [ + threading.Thread(target=_call_chat, args=(1248, 832)), + threading.Thread(target=_call_chat, args=(1024, 768)), + ] + for t in threads: + t.start() + for t in threads: + t.join() + + # TODO @ZJY + # assert (1248, 832) in results + # assert (1024, 768) in results + + +@pytest.mark.parametrize("omni_server", t2i_models, indirect=True) +def test_t2i_concurrent_requests_different_sizes(omni_server) -> None: + """Test /v1/images/generations concurrent requests with different sizes.""" + base_url = f"http://{omni_server.host}:{omni_server.port}" + url = f"{base_url}/v1/images/generations" + + barrier = threading.Barrier(2) + results: list[tuple[int, int]] = [] + + def _call_generate(size: str) -> None: + payload: dict[str, Any] = { + "prompt": "cute cat playing with a ball", + "n": 1, + "size": size, + "response_format": "b64_json", + "num_inference_steps": 2, + } + barrier.wait() + response = requests.post(url, json=payload, timeout=120) + assert response.status_code == 200 + data = response.json() + image_b64 = data["data"][0]["b64_json"] + image_bytes = base64.b64decode(image_b64) + img = Image.open(BytesIO(image_bytes)) + img.load() + results.append(img.size) + + threads = [ + threading.Thread(target=_call_generate, args=("512x512",)), + threading.Thread(target=_call_generate, args=("768x512",)), + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert (512, 512) in results + assert (768, 512) in results diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..e912c420dc2119c239d74285ae3d992be82609d7 --- /dev/null +++ b/tests/e2e/online_serving/test_images_generations_lora.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving test for /v1/images/generations with per-request LoRA. + +This validates: +- The API server accepts a per-request `lora` object in the Images API payload. +- LoRA can be switched per request (adapter A -> adapter B -> no LoRA). +- Output correctness is asserted using a small image slice with tolerance. +""" + +import base64 +import json +import os +from io import BytesIO +from pathlib import Path + +import numpy as np +import pytest +import requests +import torch +from PIL import Image +from safetensors.torch import save_file + +from tests.conftest import OmniServer + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODEL = "Tongyi-MAI/Z-Image-Turbo" + + +PROMPT = "a photo of a cat sitting on a laptop keyboard" +SIZE = "256x256" +SEED = 42 + + +@pytest.fixture(scope="module") +def omni_server(): + with OmniServer(MODEL, ["--num-gpus", "1"]) as server: + yield server + + +def _write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0): + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default. 
+ dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1). + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + if q_scale: + lora_b[:dim, 0] = q_scale + if k_scale: + lora_b[dim : 2 * dim, 0] = k_scale + if v_scale: + lora_b[2 * dim :, 0] = v_scale + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + + +def _post_images(server: OmniServer, payload: dict) -> Image.Image: + url = f"http://{server.host}:{server.port}/v1/images/generations" + resp = requests.post(url, json=payload, headers={"Authorization": "Bearer EMPTY"}, timeout=900) + resp.raise_for_status() + data = resp.json() + b64 = data["data"][0]["b64_json"] + img_bytes = base64.b64decode(b64) + img = Image.open(BytesIO(img_bytes)) + img.load() + return img.convert("RGB") + + +def _image_blue_tail_slice(img: Image.Image) -> np.ndarray: + arr = np.asarray(img, dtype=np.uint8) + assert arr.ndim == 3 and arr.shape[-1] == 3 + tail = arr[-3:, -3:, -1].astype(np.float32) + assert tail.shape == (3, 3) + return tail + + +def _slice_diff_stats(actual: np.ndarray, expected: np.ndarray) -> tuple[float, float]: + diff = np.abs(actual - expected) + return float(diff.max()), float(diff.mean()) + + +def _assert_slice_close( + actual: np.ndarray, + expected: np.ndarray, + *, + label: str, + base_max: float, + base_mean: float, +) -> None: + assert actual.shape == (3, 3) + assert expected.shape == (3, 3) + max_diff, mean_diff = _slice_diff_stats(actual, expected) + # NOTE: Different attention backends / torch.compile can introduce small + # floating-point drift that shows up as a few LSBs in uint8 pixels. Keep + # the reset check tolerant but bounded to avoid flaky CI. + max_thresh = max(10.0, base_max + 4.0) + mean_thresh = max(6.0, base_mean + 4.0) + assert max_diff <= max_thresh and mean_diff <= mean_thresh, ( + f"{label} slice mismatch (max={max_diff:.1f} > {max_thresh:.1f} or " + f"mean={mean_diff:.1f} > {mean_thresh:.1f}): {actual.tolist()}" + ) + + +def _assert_slice_diff(actual: np.ndarray, baseline: np.ndarray, *, label: str) -> None: + assert actual.shape == (3, 3) + assert baseline.shape == (3, 3) + diff = np.abs(actual - baseline).mean() + assert diff > 0.1, f"{label} slice diff too small: {diff} ({actual.tolist()} vs {baseline.tolist()})" + + +def _basic_payload() -> dict: + return { + "prompt": PROMPT, + "n": 1, + "size": SIZE, + "num_inference_steps": 2, + "guidance_scale": 0.0, + "seed": SEED, + } + + +def test_images_generations_per_request_lora_switching(omni_server: OmniServer, tmp_path: Path) -> None: + # Base generation. + base_img = _post_images(omni_server, _basic_payload()) + base_slice = _image_blue_tail_slice(base_img) + base_ref_img = _post_images(omni_server, _basic_payload()) + base_ref_slice = _image_blue_tail_slice(base_ref_img) + base_ref_max, base_ref_mean = _slice_diff_stats(base_ref_slice, base_slice) + + # Adapter A: apply delta to V slice only. 
+ lora_a_dir = tmp_path / "zimage_lora_a" + _write_zimage_lora(lora_a_dir, v_scale=8.0) + payload_a = _basic_payload() + payload_a["lora"] = {"name": "a", "path": str(lora_a_dir), "scale": 64.0} + img_a = _post_images(omni_server, payload_a) + a_slice = _image_blue_tail_slice(img_a) + _assert_slice_diff(a_slice, base_slice, label="lora_a_vs_base") + a_vs_base = float(np.abs(a_slice - base_slice).mean()) + + # Adapter B: apply delta to K slice only (should differ from adapter A). + lora_b_dir = tmp_path / "zimage_lora_b" + _write_zimage_lora(lora_b_dir, k_scale=4.0) + payload_b = _basic_payload() + payload_b["lora"] = {"name": "b", "path": str(lora_b_dir), "scale": 64.0} + img_b = _post_images(omni_server, payload_b) + b_slice = _image_blue_tail_slice(img_b) + _assert_slice_diff(b_slice, base_slice, label="lora_b_vs_base") + _assert_slice_diff(b_slice, a_slice, label="lora_b_vs_lora_a") + b_vs_base = float(np.abs(b_slice - base_slice).mean()) + b_vs_a = float(np.abs(b_slice - a_slice).mean()) + + # Ensure switching back to no-LoRA restores the base output. + base_img_2 = _post_images(omni_server, _basic_payload()) + base_slice_2 = _image_blue_tail_slice(base_img_2) + _, base_reset_mean = _slice_diff_stats(base_slice_2, base_slice) + _assert_slice_close( + base_slice_2, + base_slice, + label="base_after_reset", + base_max=base_ref_max, + base_mean=base_ref_mean, + ) + + # Ensure LoRA effects are clearly above the baseline drift. + min_delta = max(base_reset_mean + 1.0, 1.5) + assert a_vs_base > min_delta, f"lora_a_vs_base drift too small: {a_vs_base} <= {min_delta}" + assert b_vs_base > min_delta, f"lora_b_vs_base drift too small: {b_vs_base} <= {min_delta}" + assert b_vs_a > min_delta, f"lora_b_vs_lora_a drift too small: {b_vs_a} <= {min_delta}" diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..073419fb8381273a5b33ff2f1b94dc75a0e16803 --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E Online tests for Qwen3-Omni model with video input and audio output. 
+""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +import concurrent.futures +import threading +import time +from pathlib import Path + +import openai +import pytest + +from tests.conftest import ( + OmniServer, + convert_audio_to_text, + cosine_similarity_text, + dummy_messages_from_mix_data, + generate_synthetic_audio, + generate_synthetic_image, + generate_synthetic_video, + merge_base64_and_convert_to_text, + modify_stage_config, +) +from vllm_omni.platforms import current_omni_platform + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] + + +def get_default_config(): + return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml") + + +def get_chunk_config(): + path = modify_stage_config( + get_default_config(), + updates={ + "async_chunk": True, + "stage_args": { + 0: { + "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk" + }, + 1: { + "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk" + }, + }, + }, + deletes={"stage_args": {2: ["custom_process_input_func"]}}, + ) + return path + + +CHUNK_CONFIG_PATH = get_chunk_config() +# CI stage config for 2xH100-80G GPUs or AMD GPU MI325 +if current_omni_platform.is_rocm(): + # ROCm stage config optimized for MI325 GPU + stage_configs = [str(Path(__file__).parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")] +else: + stage_configs = [get_default_config(), CHUNK_CONFIG_PATH] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +_omni_server_lock = threading.Lock() + + +@pytest.fixture(scope="module") +def omni_server(request): + """Start vLLM-Omni server as a subprocess with actual model weights. + Uses session scope so the server starts only once for the entire test session. + Multi-stage initialization can take 10-20+ minutes. + """ + with _omni_server_lock: + model, stage_config_path = request.param + + print(f"Starting OmniServer with model: {model}") + + with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server: + print("OmniServer started successfully") + yield server + print("OmniServer stopping...") + + print("OmniServer stopped") + + +@pytest.fixture +def client(omni_server): + """OpenAI client for the running vLLM-Omni server.""" + return openai.OpenAI( + base_url=f"http://{omni_server.host}:{omni_server.port}/v1", + api_key="EMPTY", + ) + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def dummy_messages_from_video_data( + video_data_url: str, + content_text: str = "Describe the video briefly.", +): + """Create messages with video data URL for OpenAI API.""" + return [ + get_system_prompt(), + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_data_url}}, + {"type": "text", "text": content_text}, + ], + }, + ] + + +def get_prompt(prompt_type="text_only"): + prompts = { + "text_only": "What is the capital of China? Answer in 20 words.", + "mix": "What is recited in the audio? What is in this image? 
Describe the video briefly.",
+    }
+    return prompts.get(prompt_type, prompts["text_only"])
+
+
+def get_max_batch_size(size_type="few"):
+    batch_sizes = {"few": 5, "medium": 100, "large": 256}
+    return batch_sizes.get(size_type, 5)
+
+
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
+    """
+    Test multi-modal input processing and text/audio output generation via OpenAI API.
+    Deploy Setting: default yaml
+    Input Modal: text + audio + video + image
+    Output Modal: text + audio
+    Input Setting: stream=True
+    Datasets: single request
+    """
+
+    # Prepare synthetic multimodal inputs
+    e2e_list = list()
+    video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+    image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+    audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
+    messages = dummy_messages_from_mix_data(
+        system_prompt=get_system_prompt(),
+        video_data_url=video_data_url,
+        image_data_url=image_data_url,
+        audio_data_url=audio_data_url,
+        content_text=get_prompt("mix"),
+    )
+
+    # Test a single streaming completion
+    start_time = time.perf_counter()
+    chat_completion = client.chat.completions.create(model=omni_server.model, messages=messages, stream=True)
+
+    text_content = ""
+    audio_data = []
+    for chunk in chat_completion:
+        for choice in chunk.choices:
+            if hasattr(choice, "delta"):
+                content = getattr(choice.delta, "content", None)
+            else:
+                content = None
+
+            modality = getattr(chunk, "modality", None)
+
+            if modality == "audio" and content:
+                audio_data.append(content)
+            elif modality == "text" and content:
+                # Text chunk - accumulate text content
+                text_content += content if content else ""
+
+    # Verify E2E
+    current_e2e = time.perf_counter() - start_time
+    print(f"the request e2e is: {current_e2e}")
+    # TODO: Verify the E2E latency after confirmation baseline.
+    e2e_list.append(current_e2e)
+
+    print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
+    # Verify audio output success
+    assert len(audio_data) > 0, "No audio output is generated"
+
+    # Verify text output success
+    assert text_content is not None and len(text_content) >= 2, "No text output is generated"
+    assert any(
+        keyword in text_content.lower() for keyword in ["square", "quadrate", "sphere", "globe", "circle", "round"]
+    ), "The output does not contain any of the keywords."
+
+    # Verify text output same as audio output
+    audio_content = merge_base64_and_convert_to_text(audio_data)
+    print(f"text content is: {text_content}")
+    print(f"audio content is: {audio_content}")
+    similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
+    print(f"similarity is: {similarity}")
+    assert similarity > 0.9, "The audio content is not the same as the text"
+
+
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_text_audio_001(client: openai.OpenAI, omni_server) -> None:
+    """
+    Test text input processing and text/audio output generation via OpenAI API.
+ Deploy Setting: default yaml + Input Modal: text + Output Modal: text + audio + Datasets: few requests + """ + + num_concurrent_requests = get_max_batch_size() + messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt()) + + e2e_list = list() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor: + # Submit multiple completion requests concurrently + futures = [ + executor.submit(client.chat.completions.create, model=omni_server.model, messages=messages) + for _ in range(num_concurrent_requests) + ] + start_time = time.perf_counter() + # Wait for all requests to complete and collect results + chat_completions = list() + for future in concurrent.futures.as_completed(futures): + chat_completions.append(future.result()) + # Verify E2E + current_e2e = time.perf_counter() - start_time + print(f"the request e2e is: {current_e2e}") + # TODO: Verify the E2E latency after confirmation baseline. + e2e_list.append(current_e2e) + + print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}") + # Verify all completions succeeded + assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded." + for chat_completion in chat_completions: + # Verify audio output success + audio_data = None + text_content = None + for choice in chat_completion.choices: + if choice.message.audio is not None: + audio_message = choice.message + audio_data = audio_message.audio.data + assert audio_message.audio.expires_at > time.time(), "The generated audio has expired." + + if choice.message.content is not None: + # Verify text output success + text_content = choice.message.content + assert "beijing" in text_content.lower(), "The output do not contain keywords." + + # Verify text output same as audio output + audio_content = convert_audio_to_text(audio_data) + print(f"text content is: {text_content}") + print(f"audio content is: {audio_content}") + similarity = cosine_similarity_text(audio_content.lower(), text_content.lower()) + print(f"similarity is: {similarity}") + assert similarity > 0.9, "The audio content is not same as the text" diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py new file mode 100644 index 0000000000000000000000000000000000000000..5db07fd89a5a1faa13b1c32af791095db1564b3b --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E Online tests for Qwen3-Omni model. 
+""" + +import concurrent.futures +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +import time +from pathlib import Path + +import openai +import pytest + +from tests.conftest import ( + OmniServer, + convert_audio_to_text, + cosine_similarity_text, + dummy_messages_from_mix_data, + generate_synthetic_audio, + generate_synthetic_image, + modify_stage_config, +) + +models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] + +# CI stage config for 2*H100-80G GPUs +stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")] + +# Create parameter combinations for model and stage config +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +def client(omni_server): + """OpenAI client for the running vLLM-Omni server.""" + return openai.OpenAI( + base_url=f"http://{omni_server.host}:{omni_server.port}/v1", + api_key="EMPTY", + ) + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def get_prompt(prompt_type="text_only"): + prompts = { + "text_only": "What is the capital of China?", + "mix": "What is recited in the audio? What is in this image? Describe the video briefly.", + } + return prompts.get(prompt_type, prompts["text_only"]) + + +def get_max_batch_size(size_type="few"): + batch_sizes = {"few": 5, "medium": 100, "large": 256} + return batch_sizes.get(size_type, 5) + + +def get_deploy_config(deploy_type="TP1"): + result = { + "TP1": { + "stage_args": { + 0: { + "engine_args.gpu_memory_utilization": 0.95, + "engine_args.tensor_parallel_size": 1, + "runtime.devices": "0", + }, + 2: {"runtime.devices": "1"}, + } + } + } + return result.get(deploy_type, result["TP1"]) + + +@pytest.mark.parametrize("test_config", test_params) +def test_text_to_text_001(test_config: tuple[str, str]) -> None: + """Test processing text, generating text output via OpenAI API.""" + model, stage_config_path = test_config + with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server: + messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt()) + + # Test single completion + api_client = client(server) + start_time = time.perf_counter() + chat_completion = api_client.chat.completions.create( + model=server.model, messages=messages, max_tokens=20, modalities=["text"] + ) + # Verify E2E + print(f"the request e2e is: {time.perf_counter() - start_time}") + # TODO: Verify the E2E latency after confirmation baseline. + + # Verify only output text + assert len(chat_completion.choices) == 1, "The generated content includes more than just text." + + # Verify text output success + text_choice = chat_completion.choices[0] + assert text_choice.message.content is not None, "No text output is generated" + assert chat_completion.usage.completion_tokens <= 20, "The output length more than the requested max_tokens." + assert "beijing" in text_choice.message.content.lower(), "The output do not contain keywords." 
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_audio_to_text_001(test_config: tuple[str, str]) -> None:
+    """Test processing audio input, generating text output via OpenAI API."""
+    model, stage_config_path = test_config
+    deploy_config = get_deploy_config()
+    deploy_config["stage_args"][0]["default_sampling_params.ignore_eos"] = True
+    stage_config_path = modify_stage_config(stage_config_path, deploy_config)
+    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+        audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(1, 1)['base64']}"
+        messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
+        # Test single completion
+        api_client = client(server)
+        start_time = time.perf_counter()
+        chat_completion = api_client.chat.completions.create(
+            model=server.model, messages=messages, max_tokens=200, modalities=["text"]
+        )
+        # Verify only output text
+        assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
+
+        # Verify text output success
+        text_choice = chat_completion.choices[0]
+        assert text_choice.message.content is not None, "No text output is generated"
+        assert chat_completion.usage.completion_tokens == 200, (
+            "The output length differs from the requested max_tokens."
+        )
+
+        # Verify E2E
+        print(f"the request e2e is: {time.perf_counter() - start_time}")
+        # TODO: Verify the E2E latency after confirmation baseline.
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_audio_to_text_audio_001(test_config: tuple[str, str]) -> None:
+    """Test processing audio input, generating text and audio output via OpenAI API."""
+
+    model, stage_config_path = test_config
+    num_concurrent_requests = get_max_batch_size()
+    stage_config_path = modify_stage_config(
+        stage_config_path,
+        {
+            "stage_args": {
+                0: {"runtime.max_batch_size": num_concurrent_requests},
+                1: {"runtime.max_batch_size": num_concurrent_requests},
+            }
+        },
+    )
+    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+        audio_data_url = []
+        for _ in range(5):
+            audio_data_url.append(f"data:audio/wav;base64,{generate_synthetic_audio(1, 5)['base64']}")
+
+        messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
+
+        # Test concurrent completions
+        api_client = client(server)
+        e2e_list = list()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
+            # Submit multiple completion requests concurrently
+            futures = [
+                executor.submit(api_client.chat.completions.create, model=server.model, messages=messages)
+                for _ in range(num_concurrent_requests)
+            ]
+            start_time = time.perf_counter()
+            # Wait for all requests to complete and collect results
+            chat_completions = list()
+            for future in concurrent.futures.as_completed(futures):
+                chat_completions.append(future.result())
+                # Verify E2E
+                current_e2e = time.perf_counter() - start_time
+                print(f"the request e2e is: {current_e2e}")
+                # TODO: Verify the E2E latency after confirmation baseline.
+                e2e_list.append(current_e2e)
+
+        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
+        # Verify all completions succeeded
+        assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
+        for chat_completion in chat_completions:
+            # Verify audio output success
+            audio_message = chat_completion.choices[1].message
+            audio_data = audio_message.audio.data
+            assert audio_data is not None, "No audio output is generated"
+            assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
+
+            # Verify text output success
+            text_choice = chat_completion.choices[0]
+            text_content = text_choice.message.content
+            assert text_choice.message.content is not None, "No text output is generated"
+
+            # Verify text output same as audio output
+            audio_content = convert_audio_to_text(audio_data)
+            print(f"text content is: {text_content}")
+            print(f"audio content is: {audio_content}")
+            similarity = cosine_similarity_text(audio_content, text_content)
+            print(f"similarity between audio and text is: {similarity}")
+            assert similarity > 0.9, "The audio content is not the same as the text"
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_image_to_text_001(test_config: tuple[str, str]) -> None:
+    """Test processing image input, generating text output via OpenAI API."""
+    model, stage_config_path = test_config
+    deploy_config = get_deploy_config()
+    stage_config_path = modify_stage_config(stage_config_path, deploy_config)
+
+    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+        image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+        messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
+        # Test single completion
+        api_client = client(server)
+        start_time = time.perf_counter()
+        chat_completion = api_client.chat.completions.create(
+            model=server.model, messages=messages, max_tokens=100, modalities=["text"]
+        )
+        # Verify E2E
+        print(f"the request e2e is: {time.perf_counter() - start_time}")
+        # TODO: Verify the E2E latency after confirmation baseline.
+
+        # Verify only output text
+        assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
+
+        # Verify text output success
+        text_choice = chat_completion.choices[0]
+        text_content = text_choice.message.content
+        assert text_content is not None, "No text output is generated"
+        assert chat_completion.usage.completion_tokens <= 100, "The output length is more than the requested max_tokens."
+        assert "square" in text_content.lower(), "The output does not contain keywords."
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_image_to_text_audio_001(test_config: tuple[str, str]) -> None:
+    """Test processing image input, generating text and audio output via OpenAI API."""
+
+    model, stage_config_path = test_config
+    num_concurrent_requests = 5
+    stage_config_path = modify_stage_config(
+        stage_config_path,
+        {
+            "stage_args": {
+                0: {"runtime.max_batch_size": num_concurrent_requests},
+                1: {"runtime.max_batch_size": num_concurrent_requests},
+            }
+        },
+    )
+    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+        image_data_url = []
+        for _ in range(4):
+            image_data_url.append(f"data:image/jpeg;base64,{generate_synthetic_image(1280, 720)['base64']}")
+
+        messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
+
+        # Test concurrent completions
+        api_client = client(server)
+        e2e_list = list()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
+            # Submit multiple completion requests concurrently
+            futures = [
+                executor.submit(
+                    api_client.chat.completions.create,
+                    model=server.model,
+                    messages=messages,
+                )
+                for _ in range(num_concurrent_requests)
+            ]
+            start_time = time.perf_counter()
+            # Wait for all requests to complete and collect results
+            chat_completions = list()
+            for future in concurrent.futures.as_completed(futures):
+                chat_completions.append(future.result())
+                # Verify E2E
+                current_e2e = time.perf_counter() - start_time
+                print(f"the request e2e is: {current_e2e}")
+                # TODO: Verify the E2E latency after confirmation baseline.
+                e2e_list.append(current_e2e)
+
+        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
+        # Verify all completions succeeded
+        assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
+        for chat_completion in chat_completions:
+            # Verify audio output success
+            audio_message = chat_completion.choices[1].message
+            audio_data = audio_message.audio.data
+            assert audio_data is not None, "No audio output is generated"
+            assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
+
+            # Verify text output success
+            text_choice = chat_completion.choices[0]
+            text_content = text_choice.message.content
+            assert text_content is not None, "No text output is generated"
+            assert "square" in text_content.lower(), "The output does not contain keywords."
+
+            # Verify text output same as audio output
+            audio_content = convert_audio_to_text(audio_data)
+            print(f"text content is: {text_content}")
+            print(f"audio content is: {audio_content}")
+            similarity = cosine_similarity_text(audio_content, text_content)
+            print(f"similarity between audio and text is: {similarity}")
+            assert similarity > 0.9, "The audio content is not the same as the text"
diff --git a/tests/e2e/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..479e4d6e99d82b54123e393a6ac1d334cf306fdd
--- /dev/null
+++ b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
@@ -0,0 +1,98 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+stage_args: +- stage_id: 0 + runtime: + devices: "0" + max_batch_size: 5 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + max_num_batched_tokens: 32768 + max_model_len: 32768 + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 1 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + ignore_eos: False + detokenize: True + repetition_penalty: 1.05 + +- stage_id: 1 + runtime: + devices: "1" + max_batch_size: 5 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 1000 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + +- stage_id: 2 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 100000 + hf_config_name: thinker_config + async_scheduling: false + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2000 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/entrypoints/openai_api/__init__.py b/tests/entrypoints/openai_api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3df484b1e27ca9ae35a342b870a49caf9e5c4630 --- /dev/null +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -0,0 +1,816 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for async image generation API endpoints. + +This module contains unit tests and integration tests (with mocking) for the +OpenAI-compatible async text-to-image generation API endpoints in api_server.py. 
+""" + +import base64 +import io +from argparse import Namespace +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi.testclient import TestClient +from PIL import Image +from vllm import SamplingParams + +from vllm_omni.entrypoints.openai.image_api_utils import ( + encode_image_base64, + parse_size, +) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# Unit Tests + + +def test_parse_size_valid(): + """Test size parsing with valid inputs""" + assert parse_size("1024x1024") == (1024, 1024) + assert parse_size("512x768") == (512, 768) + assert parse_size("256x256") == (256, 256) + assert parse_size("1792x1024") == (1792, 1024) + assert parse_size("1024x1792") == (1024, 1792) + + +def test_parse_size_invalid(): + """Test size parsing with invalid inputs""" + with pytest.raises(ValueError, match="Invalid size format"): + parse_size("invalid") + + with pytest.raises(ValueError, match="Invalid size format"): + parse_size("1024") + + with pytest.raises(ValueError, match="Invalid size format"): + parse_size("1024x") + + with pytest.raises(ValueError, match="Invalid size format"): + parse_size("x1024") + + +def test_parse_size_negative(): + """Test size parsing with negative or zero dimensions""" + with pytest.raises(ValueError, match="positive integers"): + parse_size("0x1024") + + with pytest.raises(ValueError, match="positive integers"): + parse_size("1024x0") + + with pytest.raises(ValueError): + parse_size("-1024x1024") + + +def test_parse_size_edge_cases(): + """Test size parsing with edge cases like empty strings and non-integers""" + # Empty string + with pytest.raises(ValueError, match="non-empty string"): + parse_size("") + + # Non-integer dimensions + with pytest.raises(ValueError, match="must be integers"): + parse_size("abc x def") + + with pytest.raises(ValueError, match="must be integers"): + parse_size("1024.5x768.5") + + # Missing separator (user might forget 'x') + with pytest.raises(ValueError, match="separator"): + parse_size("1024 1024") + + +def test_encode_image_base64(): + """Test image encoding to base64""" + # Create a simple test image + img = Image.new("RGB", (64, 64), color="red") + b64_str = encode_image_base64(img) + + # Should be valid base64 + assert isinstance(b64_str, str) + assert len(b64_str) > 0 + + # Should decode back to PNG + decoded = base64.b64decode(b64_str) + decoded_img = Image.open(io.BytesIO(decoded)) + + # Verify properties + assert decoded_img.size == (64, 64) + assert decoded_img.format == "PNG" + + +# Integration Tests (with mocking) + + +class MockGenerationResult: + """Mock result object from AsyncOmniDiffusion.generate()""" + + def __init__(self, images): + self.images = images + + +class FakeAsyncOmni: + """Fake AsyncOmni that yields a single diffusion output.""" + + def __init__(self): + self.stage_list = ["llm", "diffusion"] + self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()] + self.captured_sampling_params_list = None + self.captured_prompt = None + + async def generate(self, prompt, request_id, sampling_params_list): + self.captured_sampling_params_list = sampling_params_list + self.captured_prompt = prompt + images = [Image.new("RGB", (64, 64), color="green")] + yield MockGenerationResult(images) + + +@pytest.fixture +def mock_async_diffusion(): + """Mock AsyncOmniDiffusion instance that returns fake images""" + mock = Mock() + mock.is_running = True # For health endpoint + mock.check_health = AsyncMock() # For LLM mode health check + + async def 
generate(**kwargs): + # Return n PIL images wrapped in result object + print("!!!!!!!!!!!!!!!!!!!!! kwargs", kwargs) + n = kwargs["sampling_params_list"][0].num_outputs_per_prompt + mock.captured_sampling_params_list = kwargs["sampling_params_list"] + mock.captured_prompt = kwargs["prompt"] + images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)] + return MockGenerationResult(images) + + mock.generate = AsyncMock(side_effect=generate) + return mock + + +@pytest.fixture +def test_client(mock_async_diffusion): + """Create test client with mocked async diffusion engine""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + + # Set up app state with diffusion engine + app.state.engine_client = mock_async_diffusion + app.state.diffusion_engine = mock_async_diffusion # Also set for health endpoint + app.state.stage_configs = [{"stage_type": "diffusion"}] + app.state.diffusion_model_name = "Qwen/Qwen-Image" # For models endpoint + app.state.args = Namespace( + default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}', + max_generated_image_size=4096, # 64*64 + ) + + return TestClient(app) + + +@pytest.fixture +def async_omni_test_client(): + """Create test client with mocked AsyncOmni engine.""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + + app.state.engine_client = FakeAsyncOmni() + app.state.stage_configs = [{"stage_type": "llm"}, {"stage_type": "diffusion"}] + app.state.args = Namespace( + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + max_generated_image_size=4096, # 64*64 + ) + return TestClient(app) + + +def test_health_endpoint(test_client): + """Test health check endpoint for diffusion mode""" + response = test_client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_health_endpoint_no_engine(): + """Test health check endpoint when no engine is initialized""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + # Don't set any engine + + client = TestClient(app) + response = client.get("/health") + assert response.status_code == 503 + data = response.json() + assert data["status"] == "unhealthy" + + +def test_models_endpoint(test_client): + """Test /v1/models endpoint for diffusion mode""" + response = test_client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "Qwen/Qwen-Image" + assert data["data"][0]["object"] == "model" + + +def test_models_endpoint_no_engine(): + """Test /v1/models endpoint when no engine is initialized""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + # Don't set any engine + + client = TestClient(app) + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 0 + + +def test_generate_single_image(test_client): + """Test generating a single image""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 1, + "size": "1024x1024", + }, + ) + assert response.status_code == 
200 + data = response.json() + + # Check response structure + assert "created" in data + assert isinstance(data["created"], int) + assert "data" in data + assert len(data["data"]) == 1 + assert "b64_json" in data["data"][0] + + # Verify image can be decoded + img_bytes = base64.b64decode(data["data"][0]["b64_json"]) + img = Image.open(io.BytesIO(img_bytes)) + assert img.size == (64, 64) # Our mock returns 64x64 images + + +def test_generate_images_async_omni_sampling_params(async_omni_test_client): + """Test AsyncOmni path uses per-stage sampling params.""" + response = async_omni_test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 2, + "size": "256x256", + "seed": 7, + }, + ) + assert response.status_code == 200 + engine = async_omni_test_client.app.state.engine_client + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + assert captured[0].temperature == 0.1 + assert captured[1].num_outputs_per_prompt == 2 + assert captured[1].height == 256 + assert captured[1].width == 256 + assert captured[1].seed == 7 + + +def test_generate_multiple_images(test_client): + """Test generating multiple images""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a dog", + "n": 3, + "size": "512x512", + }, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 3 + + # All images should be valid + for img_data in data["data"]: + assert "b64_json" in img_data + img_bytes = base64.b64decode(img_data["b64_json"]) + img = Image.open(io.BytesIO(img_bytes)) + assert img.format == "PNG" + + +def test_with_negative_prompt(test_client): + """Test with negative prompt""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "beautiful landscape", + "negative_prompt": "blurry, low quality", + "size": "1024x1024", + }, + ) + assert response.status_code == 200 + + +def test_with_seed(test_client): + """Test with seed for reproducibility""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a tree", + "seed": 42, + "size": "1024x1024", + }, + ) + assert response.status_code == 200 + + +def test_with_custom_parameters(test_client): + """Test with custom diffusion parameters""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a mountain", + "size": "1024x1024", + "num_inference_steps": 100, + "true_cfg_scale": 5.5, + "seed": 123, + }, + ) + assert response.status_code == 200 + + +def test_invalid_size(test_client): + """Test with invalid size parameter - rejected by Pydantic""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "size": "invalid", + }, + ) + # Pydantic validation errors return 422 (Unprocessable Entity) + # "invalid" has no "x" so Pydantic rejects it + assert response.status_code == 422 + # Check error detail contains size validation message + detail = str(response.json()["detail"]) + assert "size" in detail.lower() or "invalid" in detail.lower() + + +def test_invalid_size_parse_error(test_client): + """Test with malformed size - passes Pydantic but fails parse_size()""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "size": "1024x", # Has "x" so Pydantic accepts, but parse_size() rejects + }, + ) + # parse_size() raises ValueError → endpoint converts to 400 (Bad Request) + assert response.status_code == 400 + detail = str(response.json()["detail"]) + assert "size" in detail.lower() 
or "invalid" in detail.lower() + + +def test_missing_prompt(test_client): + """Test with missing required prompt field""" + response = test_client.post( + "/v1/images/generations", + json={ + "size": "1024x1024", + }, + ) + # Pydantic validation error + assert response.status_code == 422 + + +def test_invalid_n_parameter(test_client): + """Test with invalid n parameter (out of range)""" + # n < 1 + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 0, + }, + ) + assert response.status_code == 422 + + # n > 10 + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "n": 11, + }, + ) + assert response.status_code == 422 + + +def test_url_response_format_not_supported(test_client): + """Test that URL format returns error""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + "response_format": "url", + }, + ) + # Pydantic validation errors return 422 (Unprocessable Entity) + assert response.status_code == 422 + # Check error mentions response_format or b64_json + detail = str(response.json()["detail"]) + assert "b64_json" in detail.lower() or "response" in detail.lower() + + +def test_model_not_loaded(): + """Test error when diffusion engine is not initialized""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + # Don't set diffusion_engine to simulate uninitialized state + app.state.diffusion_engine = None + + client = TestClient(app) + response = client.post( + "/v1/images/generations", + json={ + "prompt": "a cat", + }, + ) + assert response.status_code == 503 + assert "not initialized" in response.json()["detail"].lower() + + +def test_different_image_sizes(test_client): + """Test various valid image sizes""" + sizes = ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + + for size in sizes: + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "a test image", + "size": size, + }, + ) + assert response.status_code == 200, f"Failed for size {size}" + + +def test_parameter_validation(): + """Test Pydantic model validation""" + from vllm_omni.entrypoints.openai.protocol.images import ImageGenerationRequest + + # Valid request - optional parameters default to None + req = ImageGenerationRequest(prompt="test") + assert req.prompt == "test" + assert req.n == 1 + assert req.model is None + assert req.size is None # Engine will use model defaults + assert req.num_inference_steps is None # Engine will use model defaults + assert req.true_cfg_scale is None # Engine will use model defaults + + # Invalid num_inference_steps (out of range) + with pytest.raises(ValueError): + ImageGenerationRequest(prompt="test", num_inference_steps=0) + + with pytest.raises(ValueError): + ImageGenerationRequest(prompt="test", num_inference_steps=201) + + # Invalid guidance_scale (out of range) + with pytest.raises(ValueError): + ImageGenerationRequest(prompt="test", guidance_scale=-1.0) + + with pytest.raises(ValueError): + ImageGenerationRequest(prompt="test", guidance_scale=21.0) + + +# Pass-Through Tests + + +def test_parameters_passed_through(test_client, mock_async_diffusion): + """Verify all parameters passed through without modification""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "test", + "num_inference_steps": 100, + "guidance_scale": 7.5, + "true_cfg_scale": 3.0, + "seed": 42, + }, + ) + assert response.status_code == 200 + + 
# Ensure generate() was called exactly once + mock_async_diffusion.generate.assert_awaited_once() + call_kwargs = mock_async_diffusion.generate.call_args[1]["sampling_params_list"][0] + assert call_kwargs.num_inference_steps == 100 + assert call_kwargs.guidance_scale == 7.5 + assert call_kwargs.true_cfg_scale == 3.0 + assert call_kwargs.seed == 42 + + +def test_model_field_omitted_works(test_client): + """Test that omitting model field works""" + response = test_client.post( + "/v1/images/generations", + json={ + "prompt": "test", + "size": "1024x1024", + # model field omitted + }, + ) + assert response.status_code == 200 + + +def make_test_image_bytes(size=(64, 64)) -> bytes: + img = Image.new( + "RGB", + size, + ) + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def test_image_edit_images_processing(async_omni_test_client): + img_bytes_1 = make_test_image_bytes((16, 16)) + img_bytes_2 = make_test_image_bytes((32, 32)) + + # uploadfile with image key + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image", img_bytes_1), + ("image", img_bytes_2), + ], + data={"prompt": "hello world."}, + ) + assert response.status_code == 200 + engine = async_omni_test_client.app.state.engine_client + captured_prompt = engine.captured_prompt + processed_images = captured_prompt["multi_modal_data"]["image"] + assert len(processed_images) == 2 + assert isinstance(processed_images[0], Image.Image) + assert isinstance(processed_images[1], Image.Image) + assert processed_images[0].size == (16, 16) + assert processed_images[1].size == (32, 32) + + # uploadfile with image[] key + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image[]", img_bytes_2), + ("image[]", img_bytes_1), + ], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 200 + engine = async_omni_test_client.app.state.engine_client + captured_prompt = engine.captured_prompt + processed_images = captured_prompt["multi_modal_data"]["image"] + assert len(processed_images) == 2 + assert isinstance(processed_images[0], Image.Image) + assert isinstance(processed_images[1], Image.Image) + assert processed_images[0].size == (32, 32) + assert processed_images[1].size == (16, 16) + + # base64 url + buf1 = io.BytesIO() + img1 = Image.new("RGB", (16, 16)) + img1.save(buf1, format="PNG") + b64_1 = "data:image/png;base64," + base64.b64encode(buf1.getvalue()).decode() + + buf2 = io.BytesIO() + img2 = Image.new("RGB", (24, 24)) + img2.save(buf2, format="PNG") + b64_2 = "data:image/png;base64," + base64.b64encode(buf2.getvalue()).decode() + + response = async_omni_test_client.post( + "/v1/images/edits", + data={ + "prompt": "hello from base64", + "url": [b64_1, b64_2], + }, + ) + assert response.status_code == 200 + processed_images = engine.captured_prompt["multi_modal_data"]["image"] + assert len(processed_images) == 2 + assert isinstance(processed_images[0], Image.Image) + assert isinstance(processed_images[1], Image.Image) + assert processed_images[0].size == (16, 16) + assert processed_images[1].size == (24, 24) + + +def test_image_edit_parameter_pass(async_omni_test_client): + img_bytes_1 = make_test_image_bytes((16, 16)) + + # uploadfile with image key + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "size": "16x24", + "output_format": "jpeg", + "num_inference_steps": 20, + "guidance_scale": 8.0, + "seed": 1234, + "negative_prompt": "negative", + "n": 2, + }, + ) + 
assert response.status_code == 200 + engine = async_omni_test_client.app.state.engine_client + captured_prompt = engine.captured_prompt + captured_sampling_params = engine.captured_sampling_params_list[-1] + + assert captured_prompt["prompt"] == "hello world." + assert captured_prompt["negative_prompt"] == "negative" + assert captured_sampling_params.num_inference_steps == 20 + assert captured_sampling_params.guidance_scale == 8.0 + assert captured_sampling_params.seed == 1234 + assert captured_sampling_params.num_outputs_per_prompt == 2 + assert captured_sampling_params.width == 16 + assert captured_sampling_params.height == 24 + + data = response.json() + # All images should be valid + for img_data in data["data"]: + assert "b64_json" in img_data + img_bytes = base64.b64decode(img_data["b64_json"]) + img = Image.open(io.BytesIO(img_bytes)) + assert img.format.lower() == "jpeg" + assert data["output_format"] == "jpeg" + assert data["size"] == "16x24" + + +def test_image_edit_parameter_default(async_omni_test_client): + img_bytes_1 = make_test_image_bytes((24, 16)) + + # uploadfile with image key + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "size": "auto", + }, + ) + assert response.status_code == 200 + engine = async_omni_test_client.app.state.engine_client + captured_sampling_params = engine.captured_sampling_params_list[-1] + + assert captured_sampling_params.width == 24 + assert captured_sampling_params.height == 16 + assert captured_sampling_params.num_outputs_per_prompt == 1 + assert captured_sampling_params.num_inference_steps == 4 + assert captured_sampling_params.guidance_scale == 7.5 + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "size": "96x96", + }, + ) + assert response.status_code == 400 + + +def test_image_edit_parameter_default_single_stage(test_client): + img_bytes_1 = make_test_image_bytes((24, 16)) + + # uploadfile with image key + response = test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + }, + ) + assert response.status_code == 200 + engine = test_client.app.state.engine_client + captured_sampling_params = engine.captured_sampling_params_list[0] + + assert captured_sampling_params.width == 24 + assert captured_sampling_params.height == 16 + assert captured_sampling_params.num_outputs_per_prompt == 1 + assert captured_sampling_params.num_inference_steps == 4 + assert captured_sampling_params.guidance_scale == 7.5 + + response = test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "size": "96x96", + }, + ) + assert response.status_code == 400 + + +def test_image_edit_compression_jpeg(test_client): + img_bytes_1 = make_test_image_bytes((16, 16)) + # uploadfile with image key + response = test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={"prompt": "hello world.", "output_format": "jpeg", "output_compression": 100}, + ) + assert response.status_code == 200 + data = response.json() + img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"]) + img = Image.open(io.BytesIO(img_bytes_100)) + assert img.format.lower() == "jpeg" + + response = test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "output_format": "jpeg", + "output_compression": 50, + }, + ) + assert 
response.status_code == 200 + data = response.json() + img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"]) + + response = test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "output_format": "jpeg", + "output_compression": 10, + }, + ) + assert response.status_code == 200 + data = response.json() + img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"]) + + assert len(img_bytes_10) < len(img_bytes_50) + assert len(img_bytes_50) < len(img_bytes_100) + + +def test_image_edit_compression_png(async_omni_test_client): + img_bytes_1 = make_test_image_bytes((16, 16)) + # uploadfile with image key + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={"prompt": "hello world.", "output_format": "PNG", "output_compression": 100}, + ) + assert response.status_code == 200 + data = response.json() + img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"]) + img = Image.open(io.BytesIO(img_bytes_100)) + assert img.format.lower() == "png" + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "output_format": "PNG", + "output_compression": 50, + }, + ) + assert response.status_code == 200 + data = response.json() + img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"]) + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", img_bytes_1)], + data={ + "prompt": "hello world.", + "output_format": "PNG", + "output_compression": 10, + }, + ) + assert response.status_code == 200 + data = response.json() + img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"]) + + assert len(img_bytes_10) < len(img_bytes_50) + assert len(img_bytes_50) < len(img_bytes_100) diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py new file mode 100644 index 0000000000000000000000000000000000000000..240fd2051edaefe91d2951e90bfe79158494ec32 --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Unit tests for OmniOpenAIServingChat sampling params handling. + +Tests that standard OpenAI API parameters (max_tokens, temperature, etc.) +are correctly applied to the comprehension stage while preserving YAML defaults. 
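+
+These tests construct the serving object via object.__new__ with a mocked engine
+client and mocked stage list, so no real engine or model is loaded.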
+""" + +from unittest.mock import MagicMock + +import pytest +from vllm.sampling_params import SamplingParams + + +@pytest.fixture +def mock_comprehension_stage(): + """Create a mock comprehension stage with is_comprehension=True.""" + stage = MagicMock() + stage.is_comprehension = True + stage.model_stage = "comprehension" + return stage + + +@pytest.fixture +def mock_other_stage(): + """Create a mock non-comprehension stage.""" + stage = MagicMock() + stage.is_comprehension = False + stage.model_stage = "other" + return stage + + +@pytest.fixture +def default_comprehension_params(): + """Default sampling params for comprehension stage (from YAML).""" + return SamplingParams( + temperature=0.4, + top_p=0.9, + top_k=1, + max_tokens=2048, + seed=42, + repetition_penalty=1.05, + ) + + +@pytest.fixture +def default_other_params(): + """Default sampling params for non-comprehension stage (from YAML).""" + return SamplingParams( + temperature=0.9, + top_k=50, + max_tokens=4096, + seed=42, + ) + + +@pytest.fixture +def mock_engine_client(mock_comprehension_stage, mock_other_stage, default_comprehension_params, default_other_params): + """Create mock engine client with stage_list and default_sampling_params_list.""" + engine_client = MagicMock() + engine_client.stage_list = [mock_comprehension_stage, mock_other_stage] + engine_client.default_sampling_params_list = [ + default_comprehension_params, + default_other_params, + ] + return engine_client + + +@pytest.fixture +def serving_chat(mock_engine_client): + """Create OmniOpenAIServingChat instance with mocked dependencies.""" + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + # Create instance without calling __init__ + instance = object.__new__(OmniOpenAIServingChat) + instance.engine_client = mock_engine_client + return instance + + +@pytest.fixture +def mock_request(): + """Create a mock request with all OpenAI sampling params set to None.""" + request = MagicMock() + # OpenAI standard sampling fields + request.temperature = None + request.top_p = None + request.max_tokens = None + request.seed = None + request.stop = None + request.frequency_penalty = None + request.presence_penalty = None + return request + + +# ============================================================================= +# Tests for _OPENAI_SAMPLING_FIELDS constant +# ============================================================================= + + +def test_openai_sampling_fields_contains_expected_fields(): + """Test that _OPENAI_SAMPLING_FIELDS contains all expected OpenAI params.""" + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + expected_fields = { + "temperature", + "top_p", + "max_tokens", + "seed", + "stop", + "frequency_penalty", + "presence_penalty", + } + assert OmniOpenAIServingChat._OPENAI_SAMPLING_FIELDS == expected_fields + + +# ============================================================================= +# Tests for _build_sampling_params_list_from_request +# ============================================================================= + + +def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_request): + """Test that YAML defaults are preserved when request has no params.""" + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + assert len(result) == 2 + comprehension_params = result[0] + assert comprehension_params.temperature == 0.4 + assert comprehension_params.top_p == 0.9 + assert comprehension_params.top_k == 1 # YAML custom param 
preserved + assert comprehension_params.max_tokens == 2048 + assert comprehension_params.seed == 42 + assert comprehension_params.repetition_penalty == 1.05 # YAML custom param preserved + + +def test_request_temperature_overrides_yaml_default(serving_chat, mock_request): + """Test that request temperature overrides YAML default.""" + mock_request.temperature = 0.8 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + comprehension_params = result[0] + assert comprehension_params.temperature == 0.8 # Overridden + assert comprehension_params.seed == 42 # Preserved from YAML + assert comprehension_params.top_k == 1 # YAML custom param preserved + + +def test_request_top_p_overrides_yaml_default(serving_chat, mock_request): + """Test that request top_p overrides YAML default.""" + mock_request.top_p = 0.95 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + comprehension_params = result[0] + assert comprehension_params.top_p == 0.95 # Overridden + assert comprehension_params.temperature == 0.4 # Preserved from YAML + + +def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request): + """Test that request max_tokens overrides YAML default.""" + mock_request.max_tokens = 100 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + assert result[0].max_tokens == 100 + + +def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_request): + """Test that max_tokens falls back to YAML default when not in request.""" + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + assert result[0].max_tokens == 2048 + + +def test_request_seed_overrides_yaml_default(serving_chat, mock_request): + """Test that request seed overrides YAML default.""" + mock_request.seed = 123 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + comprehension_params = result[0] + assert comprehension_params.seed == 123 # Overridden + assert comprehension_params.temperature == 0.4 # Preserved from YAML + + +def test_request_frequency_penalty_overrides(serving_chat, mock_request): + """Test that request frequency_penalty is applied.""" + mock_request.frequency_penalty = 0.5 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + assert result[0].frequency_penalty == 0.5 + + +def test_request_presence_penalty_overrides(serving_chat, mock_request): + """Test that request presence_penalty is applied.""" + mock_request.presence_penalty = 0.3 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + assert result[0].presence_penalty == 0.3 + + +def test_non_comprehension_stages_use_cloned_defaults(serving_chat, mock_request): + """Test that non-comprehension stages always use cloned YAML defaults.""" + mock_request.max_tokens = 50 + mock_request.temperature = 0.1 + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + other_params = result[1] + assert other_params.temperature == 0.9 # YAML default (not affected by request) + assert other_params.max_tokens == 4096 # YAML default (not affected by request) + assert other_params.top_k == 50 # YAML default + assert other_params.seed == 42 # YAML default + + +def test_multiple_params_override_together(serving_chat, mock_request): + """Test that multiple request params can override together.""" + mock_request.max_tokens = 200 + mock_request.temperature = 0.7 + mock_request.top_p = 0.85 + mock_request.seed = 999 + + result = 
serving_chat._build_sampling_params_list_from_request(mock_request) + + comprehension_params = result[0] + # Overridden by request + assert comprehension_params.temperature == 0.7 + assert comprehension_params.top_p == 0.85 + assert comprehension_params.max_tokens == 200 + assert comprehension_params.seed == 999 + # Preserved from YAML (not in _OPENAI_SAMPLING_FIELDS) + assert comprehension_params.top_k == 1 + assert comprehension_params.repetition_penalty == 1.05 + + +def test_yaml_custom_params_not_overridden_by_request(serving_chat, mock_request): + """Test that YAML custom params (top_k, repetition_penalty) are not affected.""" + # Even if request has these attributes, they should not override YAML + # because they're not in _OPENAI_SAMPLING_FIELDS + mock_request.top_k = 100 # Not in allowlist + mock_request.repetition_penalty = 2.0 # Not in allowlist + + result = serving_chat._build_sampling_params_list_from_request(mock_request) + + comprehension_params = result[0] + assert comprehension_params.top_k == 1 # YAML default preserved + assert comprehension_params.repetition_penalty == 1.05 # YAML default preserved + + +# ============================================================================= +# Tests for _apply_request_overrides +# ============================================================================= + + +def test_apply_request_overrides_clones_params(serving_chat, mock_request, default_comprehension_params): + """Test that _apply_request_overrides returns a cloned object.""" + result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request) + + assert result is not default_comprehension_params # Different object + + +def test_apply_request_overrides_preserves_defaults(serving_chat, mock_request, default_comprehension_params): + """Test that _apply_request_overrides preserves defaults when request has None.""" + result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request) + + assert result.temperature == 0.4 + assert result.top_p == 0.9 + assert result.seed == 42 + assert result.top_k == 1 # YAML custom param + + +def test_apply_request_overrides_applies_values(serving_chat, mock_request, default_comprehension_params): + """Test that _apply_request_overrides applies non-None request values.""" + mock_request.temperature = 0.8 + mock_request.seed = 123 + + result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request) + + assert result.temperature == 0.8 # Overridden + assert result.seed == 123 # Overridden + assert result.top_p == 0.9 # Preserved from default + assert result.top_k == 1 # YAML custom param preserved + + +# ============================================================================= +# Tests for _get_comprehension_stage_index +# ============================================================================= + + +def test_get_comprehension_stage_index_finds_first_stage(mock_engine_client): + """Test finding comprehension stage when it's at index 0.""" + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + instance = object.__new__(OmniOpenAIServingChat) + instance.engine_client = mock_engine_client + + assert instance._get_comprehension_stage_index() == 0 + + +def test_get_comprehension_stage_index_finds_second_stage(): + """Test finding comprehension stage when it's at index 1.""" + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + instance = object.__new__(OmniOpenAIServingChat) + + other = MagicMock() + 
other.is_comprehension = False + comprehension = MagicMock() + comprehension.is_comprehension = True + + instance.engine_client = MagicMock() + instance.engine_client.stage_list = [other, comprehension] + + assert instance._get_comprehension_stage_index() == 1 + + +def test_get_comprehension_stage_index_raises_when_not_found(): + """Test that ValueError is raised when no comprehension stage exists.""" + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + instance = object.__new__(OmniOpenAIServingChat) + + stage1 = MagicMock() + stage1.is_comprehension = False + stage2 = MagicMock() + stage2.is_comprehension = False + + instance.engine_client = MagicMock() + instance.engine_client.stage_list = [stage1, stage2] + + with pytest.raises(ValueError, match="No comprehension stage"): + instance._get_comprehension_stage_index() diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..f82650b8fc35a1c9ae20fa1022643b1918dbd04d --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -0,0 +1,473 @@ +# tests/entrypoints/openai/test_serving_speech.py +import logging +from inspect import Signature, signature +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin +from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.serving_speech import ( + OmniOpenAIServingSpeech, +) +from vllm_omni.outputs import OmniRequestOutput + +logger = logging.getLogger(__name__) + + +class TestAudioMixin: + @pytest.fixture + def audio_mixin(self): + return AudioMixin() + + def test_stereo_to_mono_conversion(self, audio_mixin): + stereo_tensor = np.random.rand(24000, 2).astype(np.float32) + audio_obj = CreateAudio(audio_tensor=stereo_tensor) + + with ( + patch.object( + audio_mixin, "_apply_speed_adjustment", side_effect=lambda tensor, speed, sr: (tensor, sr) + ) as mock_speed, + patch("soundfile.write") as _, + ): + audio_mixin.create_audio(audio_obj) + + # Check that the tensor passed to speed adjustment is mono + mock_speed.assert_called_once() + adjusted_tensor = mock_speed.call_args[0][0] + assert len(adjusted_tensor) == 24000 + + @patch("librosa.effects.time_stretch") + def test_speed_adjustment(self, mock_time_stretch, audio_mixin): + mock_time_stretch.return_value = np.zeros(12000) + audio_tensor = np.random.rand(24000).astype(np.float32) + + adjusted_audio, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=2.0, sample_rate=24000) + + mock_time_stretch.assert_called_with(y=audio_tensor, rate=2.0) + assert adjusted_audio.shape == (12000,) + + @patch("soundfile.write") + def test_unsupported_format_fallback(self, mock_write, audio_mixin, caplog): + audio_tensor = np.random.rand(24000).astype(np.float32) + # Use a format that is not in the list of supported formats + audio_obj = CreateAudio(audio_tensor=audio_tensor, response_format="vorbis") + + audio_mixin.create_audio(audio_obj) + + # Should fall back to 'wav' + mock_write.assert_called_once() + write_kwargs = mock_write.call_args.kwargs + assert write_kwargs["format"] == "WAV" + + def test_mono_audio_preservation(self, audio_mixin): + """Test that mono (1D) audio tensors are processed correctly and passed to writer.""" + 
mono_tensor = np.random.rand(24000).astype(np.float32) + audio_obj = CreateAudio(audio_tensor=mono_tensor) + + with patch("soundfile.write") as mock_write: + audio_mixin.create_audio(audio_obj) + + mock_write.assert_called_once() + # Verify the tensor passed to soundfile.write is the exact 1D tensor + output_tensor = mock_write.call_args[0][1] + assert output_tensor.ndim == 1 + assert output_tensor.shape == (24000,) + assert np.array_equal(output_tensor, mono_tensor) + + def test_stereo_audio_preservation(self, audio_mixin): + """Test that stereo (2D) audio tensors are processed correctly and preserved.""" + stereo_tensor = np.random.rand(24000, 2).astype(np.float32) + audio_obj = CreateAudio(audio_tensor=stereo_tensor) + + with patch("soundfile.write") as mock_write: + audio_mixin.create_audio(audio_obj) + + mock_write.assert_called_once() + # Verify the tensor passed to soundfile.write is the exact 2D tensor + output_tensor = mock_write.call_args[0][1] + assert output_tensor.ndim == 2 + assert output_tensor.shape == (24000, 2) + assert np.array_equal(output_tensor, stereo_tensor) + + def test_speed_adjustment_bypass(self, audio_mixin): + """Test that speed=1.0 bypasses the expensive librosa time stretching.""" + audio_tensor = np.random.rand(24000).astype(np.float32) + + with patch("librosa.effects.time_stretch") as mock_time_stretch: + # speed=1.0 should return immediately without calling librosa + result, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=1.0, sample_rate=24000) + + mock_time_stretch.assert_not_called() + assert np.array_equal(result, audio_tensor) + + @patch("librosa.effects.time_stretch") + def test_speed_adjustment_stereo_handling(self, mock_time_stretch, audio_mixin): + """Test that speed adjustment is attempted on stereo inputs.""" + stereo_tensor = np.random.rand(24000, 2).astype(np.float32) + # Mock return value representing a sped-up version (half length) + mock_time_stretch.return_value = np.zeros((12000, 2), dtype=np.float32) + + result, _ = audio_mixin._apply_speed_adjustment(stereo_tensor, speed=2.0, sample_rate=24000) + + mock_time_stretch.assert_called_once() + # Ensure the stereo tensor was passed to librosa + call_args = mock_time_stretch.call_args + assert np.array_equal(call_args.kwargs["y"], stereo_tensor) + assert call_args.kwargs["rate"] == 2.0 + assert result.shape == (12000, 2) + + +# Helper to create mock model output for endpoint tests +def create_mock_audio_output_for_test( + request_id: str = "speech-mock-123", +) -> OmniRequestOutput: + class MockCompletionOutput: + def __init__(self, index: int = 0): + self.index = index + self.text = "" + self.token_ids = [] + self.finish_reason = "stop" + self.stop_reason = None + self.logprobs = None + + class MockRequestOutput: + def __init__(self, request_id: str, audio_tensor: torch.Tensor): + self.request_id = request_id + self.outputs = [MockCompletionOutput(index=0)] + self.multimodal_output = {"audio": audio_tensor} + self.finished = True + self.prompt_token_ids = None + self.encoder_prompt_token_ids = None + self.num_cached_tokens = None + self.prompt_logprobs = None + self.kv_transfer_params = None + + num_samples = 24000 + audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples)) + mock_request_output = MockRequestOutput(request_id=request_id, audio_tensor=audio_tensor) + + return OmniRequestOutput( + stage_id=0, + final_output_type="audio", + request_output=mock_request_output, + ) + + +def create_mock_audio_output_on_completion_for_test( + request_id: str = 
"speech-mock-completion-123", +) -> OmniRequestOutput: + class MockCompletionOutput: + def __init__(self, audio_tensor: torch.Tensor, index: int = 0): + self.index = index + self.text = "" + self.token_ids = [] + self.finish_reason = "stop" + self.stop_reason = None + self.logprobs = None + self.multimodal_output = {"audio": audio_tensor, "sr": 24000} + + class MockRequestOutput: + def __init__(self, request_id: str, audio_tensor: torch.Tensor): + self.request_id = request_id + self.outputs = [MockCompletionOutput(audio_tensor=audio_tensor, index=0)] + self.multimodal_output = {} + self.finished = True + self.prompt_token_ids = None + self.encoder_prompt_token_ids = None + self.num_cached_tokens = None + self.prompt_logprobs = None + self.kv_transfer_params = None + + num_samples = 24000 + audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples)) + mock_request_output = MockRequestOutput(request_id=request_id, audio_tensor=audio_tensor) + + return OmniRequestOutput( + stage_id=0, + final_output_type="audio", + request_output=mock_request_output, + ) + + +@pytest.fixture +def test_app(): + # Mock the engine client + mock_engine_client = MagicMock() + mock_engine_client.errored = False + + async def mock_generate_fn(*args, **kwargs): + yield create_mock_audio_output_for_test(request_id=kwargs.get("request_id")) + + mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn) + mock_engine_client.default_sampling_params_list = [{}] + + # Mock models to have an is_base_model method + mock_models = MagicMock() + mock_models.is_base_model.return_value = True + + mock_request_logger = MagicMock() + + speech_server = OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=mock_request_logger, + ) + + # Patch the signature of create_speech to remove 'raw_request' for FastAPI route introspection + original_create_speech = speech_server.create_speech + _ = MagicMock(side_effect=original_create_speech) + + sig = signature(original_create_speech) + + new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"] + + new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation) + + async def awaitable_patched_create_speech(*args, **kwargs): + return await original_create_speech(*args, **kwargs) + + awaitable_patched_create_speech.__signature__ = new_sig + speech_server.create_speech = awaitable_patched_create_speech + + app = FastAPI() + app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None) + + # Add list_voices endpoint + async def list_voices(): + speakers = sorted(speech_server.supported_speakers) if speech_server.supported_speakers else [] + return {"voices": speakers} + + app.add_api_route("/v1/audio/voices", list_voices, methods=["GET"]) + + return app + + +@pytest.fixture +def client(test_app): + return TestClient(test_app) + + +class TestSpeechAPI: + def test_create_speech_success(self, client): + payload = { + "input": "Hello world", + "model": "tts-model", + "voice": "alloy", + "response_format": "wav", + } + response = client.post("/v1/audio/speech", json=payload) + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert len(response.content) > 0 + + def test_create_speech_mp3_format(self, client): + payload = { + "input": "Hello world", + "model": "tts-model", + "voice": "alloy", + "response_format": "mp3", + } + response = client.post("/v1/audio/speech", json=payload) + 
assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert len(response.content) > 0 + + def test_create_speech_reads_audio_from_completion_output(self, test_app): + mock_engine_client = MagicMock() + mock_engine_client.errored = False + async def mock_generate_fn(*args, **kwargs): + yield create_mock_audio_output_on_completion_for_test(request_id=kwargs.get("request_id")) + + mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn) + mock_engine_client.default_sampling_params_list = [{}] + + mock_models = MagicMock() + mock_models.is_base_model.return_value = True + + speech_server = OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=MagicMock(), + ) + + original_create_speech = speech_server.create_speech + sig = signature(original_create_speech) + new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"] + new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation) + + async def awaitable_patched_create_speech(*args, **kwargs): + return await original_create_speech(*args, **kwargs) + + awaitable_patched_create_speech.__signature__ = new_sig + speech_server.create_speech = awaitable_patched_create_speech + + app = FastAPI() + app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None) + + client = TestClient(app) + payload = { + "input": "Hello world", + "model": "tts-model", + "voice": "alloy", + "response_format": "wav", + } + response = client.post("/v1/audio/speech", json=payload) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert len(response.content) > 0 + + def test_create_speech_invalid_format(self, client): + payload = { + "input": "Hello world", + "model": "tts-model", + "voice": "alloy", + "response_format": "invalid_format", + } + response = client.post("/v1/audio/speech", json=payload) + assert response.status_code == 422 # Unprocessable Entity + + @patch("vllm_omni.entrypoints.openai.serving_speech.OmniOpenAIServingSpeech.create_audio") + def test_speed_parameter_is_used(self, mock_create_audio, test_app): + client = TestClient(test_app) + + mock_audio_response = MagicMock() + mock_audio_response.audio_data = b"dummy_audio" + mock_audio_response.media_type = "audio/wav" + mock_create_audio.return_value = mock_audio_response + + payload = { + "input": "This should be fast.", + "model": "tts-model", + "voice": "alloy", + "response_format": "wav", + "speed": 2.5, + } + client.post("/v1/audio/speech", json=payload) + + mock_create_audio.assert_called_once() + call_args = mock_create_audio.call_args[0] + audio_obj = call_args[0] + assert isinstance(audio_obj, CreateAudio) + assert audio_obj.speed == 2.5 + + def test_list_voices_endpoint(self, client): + response = client.get("/v1/audio/voices") + assert response.status_code == 200 + assert "voices" in response.json() + + +class TestTTSMethods: + """Unit tests for TTS validation and parameter building.""" + + @pytest.fixture + def speech_server(self): + mock_engine_client = MagicMock() + mock_engine_client.errored = False + mock_engine_client.stage_list = None + mock_models = MagicMock() + mock_models.is_base_model.return_value = True + return OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=MagicMock(), + ) + + def test_is_tts_model(self, speech_server): + """Test TTS model detection.""" + # No stage_list -> False + 
assert speech_server._is_tts_model() is False + + # With qwen3_tts stage -> True + mock_stage = MagicMock() + mock_stage.model_stage = "qwen3_tts" + speech_server.engine_client.stage_list = [mock_stage] + assert speech_server._is_tts_model() is True + + def test_build_tts_prompt(self, speech_server): + """Test TTS prompt format.""" + prompt = speech_server._build_tts_prompt("Hello") + assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n" + + def test_validate_tts_request_basic(self, speech_server): + """Test basic validation cases.""" + # Empty input + req = OpenAICreateSpeechRequest(input="") + assert speech_server._validate_tts_request(req) == "Input text cannot be empty" + + # Invalid language + req = OpenAICreateSpeechRequest(input="Hello", language="InvalidLang") + assert "Invalid language" in speech_server._validate_tts_request(req) + + # When no speakers loaded, any voice is accepted (unconstrained) + req = OpenAICreateSpeechRequest(input="Hello", voice="Invalid") + assert speech_server._validate_tts_request(req) is None + + # Valid request + req = OpenAICreateSpeechRequest(input="Hello", voice="Vivian") + assert speech_server._validate_tts_request(req) is None + + def test_validate_tts_request_task_types(self, speech_server): + """Test task-specific validation.""" + # Base task requires ref_audio + req = OpenAICreateSpeechRequest(input="Hello", task_type="Base") + assert "ref_audio" in speech_server._validate_tts_request(req) + + # VoiceDesign requires instructions + req = OpenAICreateSpeechRequest(input="Hello", task_type="VoiceDesign") + assert "instructions" in speech_server._validate_tts_request(req) + + # ref_text only for Base + req = OpenAICreateSpeechRequest(input="Hello", ref_text="text") + assert "Base task" in speech_server._validate_tts_request(req) + + def test_build_tts_params(self, speech_server): + """Test TTS parameter building.""" + req = OpenAICreateSpeechRequest(input="Hello", voice="Ryan", language="English") + params = speech_server._build_tts_params(req) + + assert params["text"] == ["Hello"] + assert params["speaker"] == ["Ryan"] + assert params["language"] == ["English"] + assert params["task_type"] == ["CustomVoice"] + assert "max_new_tokens" not in params + + def test_build_tts_params_with_explicit_max_new_tokens(self, speech_server): + """Test explicit max_new_tokens override.""" + req = OpenAICreateSpeechRequest( + input="Hello", + task_type="Base", + ref_audio="data:audio/wav;base64,AAAA", + max_new_tokens=128, + ) + params = speech_server._build_tts_params(req) + + assert params["max_new_tokens"] == [128] + + def test_load_supported_speakers(self): + """Test _load_supported_speakers.""" + mock_engine_client = MagicMock() + mock_engine_client.errored = False + mock_engine_client.stage_list = None + + # Mock talker_config with mixed-case speaker names + mock_talker_config = MagicMock() + mock_talker_config.spk_id = {"Ryan": 0, "Vivian": 1, "Aiden": 2} + mock_engine_client.model_config.hf_config.talker_config = mock_talker_config + + mock_models = MagicMock() + mock_models.is_base_model.return_value = True + + server = OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=MagicMock(), + ) + + # Verify speakers are normalized to lowercase + assert server.supported_speakers == {"ryan", "vivian", "aiden"} diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py new file mode 100644 index 
0000000000000000000000000000000000000000..6b49eba2c60b4cedf9880fd03a2ab4532ac4c65a --- /dev/null +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.entrypoints import omni as omni_module +from vllm_omni.entrypoints.async_omni import AsyncOmni + + +def test_default_stage_config_includes_cache_backend(monkeypatch): + """Ensure cache_backend/cache_config are preserved in default diffusion stage.""" + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) + monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) + + omni = AsyncOmni( + model="dummy-model", + cache_backend="cache_dit", + cache_config='{"Fn_compute_blocks": 2}', + vae_use_slicing=True, + ulysses_degree=2, + ) + + stage_cfg = omni.stage_configs[0] + engine_args = stage_cfg.engine_args + + assert engine_args.get("cache_backend") == "cache_dit" + cache_config = engine_args.get("cache_config") + assert cache_config["Fn_compute_blocks"] == 2 + assert engine_args.get("vae_use_slicing") is True + parallel_config = engine_args.get("parallel_config") + if hasattr(parallel_config, "get"): + ulysses_degree = parallel_config.get("ulysses_degree") + else: + ulysses_degree = getattr(parallel_config, "ulysses_degree", None) + assert ulysses_degree == 2 + + +def test_default_cache_config_used_when_missing(monkeypatch): + """Ensure default cache_config is applied when cache_backend is set.""" + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) + monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) + + omni = AsyncOmni( + model="dummy-model", + cache_backend="cache_dit", + ) + + engine_args = omni.stage_configs[0].engine_args + cache_config = engine_args.get("cache_config") + assert cache_config is not None + assert cache_config["Fn_compute_blocks"] == 1 + + +def test_default_stage_devices_from_sequence_parallel(monkeypatch): + """Ensure devices list reflects sequence parallel size when no parallel_config is provided.""" + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) + monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None) + monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) + monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) + + omni = AsyncOmni( + model="dummy-model", + ulysses_degree=2, + ring_degree=2, + ) + + stage_cfg = omni.stage_configs[0] + runtime = stage_cfg.runtime + if hasattr(runtime, "get"): + devices = runtime.get("devices") + else: + devices = getattr(runtime, "devices", None) + assert devices == "0,1,2,3" diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c4884e3abd152474b62bf7f595b62f2a887ac629 --- /dev/null +++ b/tests/entrypoints/test_omni_diffusion.py @@ -0,0 +1,1103 @@ +import uuid +import warnings +from queue import Empty, Queue 
+from typing import Any +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + + +class _FakeEngineArgs(dict): + """Fake engine args that can be used both as object attributes and as **kwargs.""" + + def __init__(self, args_dict: dict[str, Any]): + super().__init__(args_dict) + # Add required attributes if not present + if "model_stage" not in self: + self["model_stage"] = None + if "engine_output_type" not in self: + self["engine_output_type"] = None + # Also set as attributes for object-style access + for key, value in self.items(): + setattr(self, key, value) + + +class _FakeStageConfig: + """Fake stage config object that mimics the real stage config structure.""" + + def __init__(self, config_dict: dict[str, Any]): + # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs) + engine_args_dict = config_dict.get("engine_args", {}) + self.engine_args = _FakeEngineArgs(engine_args_dict) + self.final_output = config_dict.get("final_output", False) + self.final_output_type = config_dict.get("final_output_type", None) + self.stage_id = config_dict.get("stage_id", 0) + # Store original dict for reference + self._config_dict = config_dict + + +class _FakeQueue: + """Fake queue using standard library Queue to replace mp.Queue.""" + + def __init__(self, maxsize=0): + self._queue = Queue(maxsize=maxsize) + + def put(self, item): + self._queue.put(item) + + def put_nowait(self, item): + self._queue.put_nowait(item) + + def get(self): + return self._queue.get() + + def get_nowait(self): + return self._queue.get_nowait() + + def empty(self): + return self._queue.empty() + + +class _FakeStage: + """Lightweight Stage stub for multi-process pipeline version with queue support.""" + + def __init__(self, config, stage_init_timeout: int = 300): + # Handle both dict and object configs + if isinstance(config, dict): + config = _FakeStageConfig(config) + self.config = config + self.stage_config = config + self.engine = None + self.engine_outputs = None + # Set attributes that OmniStage expects + self.stage_id = getattr(config, "stage_id", 0) + self.engine_args = config.engine_args + self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "diffusion" + # set default sampling params + self.default_sampling_params = OmniDiffusionSamplingParams(num_inference_steps=1) + # Allow configuring final_output and final_output_type + self.final_output = config.final_output if hasattr(config, "final_output") else False + self.final_output_type = getattr(config, "final_output_type", None) + # Configurable processing logic, default returns placeholder + processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) + self._processed_input = processed_input + # Queue references (set by attach_queues) + self._in_q = None + self._out_q = None + self._proc = None # Mock process reference + self._stage_init_timeout = max(0, int(stage_init_timeout)) + + def attach_queues(self, in_q, out_q): + """Attach input and output queues.""" + self._in_q = in_q + self._out_q = out_q + + def init_stage_worker( + self, + model: str, + *, + is_async: bool = False, + shm_threshold_bytes: int = 65536, + 
ctx=None, + batch_timeout: int = 10, + **kwargs, + ): + """Mock init_stage_worker: don't start real process, just send stage_ready message.""" + # Create a mock process object + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + # Send stage_ready message to output queue + if self._out_q is not None: + try: + self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) + except Exception: + pass + + def stop_stage_worker(self): + """Mock stop_stage_worker: clean up queue references.""" + if self._in_q is not None: + try: + self._in_q.put_nowait(SHUTDOWN_TASK) + except Exception: + pass + + def submit(self, payload: dict[str, Any]): + """Submit task to input queue.""" + if self._in_q is not None: + self._in_q.put(payload) + + def try_collect(self) -> Any: + """Non-blocking collect from output queue.""" + if self._out_q is None: + return None + try: + return self._out_q.get_nowait() + except Empty: + return None + + def set_engine_outputs(self, outputs): + """Set engine outputs for the stage.""" + self.engine_outputs = outputs + + def process_engine_inputs(self, stage_list, prompts): + """Process engine inputs: return preset processed result.""" + return self._processed_input + + +class _FakeEngine: + """Lightweight Engine stub: provides generate iterator output.""" + + def __init__(self, outputs: list[Any]): + self._outputs = outputs + + def generate(self, prompts, sampling_params): + # Record the most recent prompts for outer assertions + self._last_prompts = prompts + # Simplified: return preset list at once, ensuring iterability + yield from self._outputs + + +@pytest.fixture +def fake_stage_config(): + return { + # Don't include 'model' in engine_args since it's passed separately + "engine_args": {}, + "final_output": True, + "final_output_type": "text", + # Second stage will use processed_input to verify the chain + "processed_input": ["processed-by-stage"], + } + + +def _setup_engine_mocks(monkeypatch): + """Helper function to set up common engine mocks.""" + fake_engine = MagicMock() + # Add necessary attributes to fake_engine + fake_engine.tokenizer = MagicMock() + fake_engine.log_stats = False + fake_engine.vllm_config = MagicMock() + fake_engine.vllm_config.model_config = MagicMock() + fake_engine.vllm_config.model_config.io_processor_plugin = None + fake_engine.get_supported_tasks = MagicMock(return_value=[]) + fake_engine.model_config = MagicMock() + fake_engine.model_config.io_processor_plugin = None + # Add registry with resolve_model_cls method + fake_registry = MagicMock() + fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch")) + fake_engine.model_config.registry = fake_registry + fake_engine.vllm_config.model_config.registry = fake_registry + + monkeypatch.setattr( + "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", + lambda **kw: fake_engine, + raising=False, + ) + + # Mock model_config.registry.resolve_model_cls to return a tuple + # Use a real class instead of MagicMock to avoid inspect.getsource issues + class FakeModelClass: + pass + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils.get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils._get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + # Mock 
try_create_mm_pooling_model_cls to return the class as-is + monkeypatch.setattr( + "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", + lambda model_cls: model_cls, + raising=False, + ) + + # Mock _enable_processor_cache to return False + monkeypatch.setattr( + "vllm.multimodal.cache._enable_processor_cache", + lambda model_config, mm_registry: False, + raising=False, + ) + + # Mock get_io_processor to return None + monkeypatch.setattr( + "vllm.plugins.io_processors.get_io_processor", + lambda vllm_config, io_processor_plugin: None, + raising=False, + ) + + +def _setup_multiprocessing_mocks(monkeypatch): + """Helper function to set up multiprocessing mocks.""" + import multiprocessing as mp + + # Mock Process + fake_process_class = MagicMock() + fake_process_instance = MagicMock() + fake_process_instance.start = MagicMock() + fake_process_instance.join = MagicMock() + fake_process_instance.is_alive = MagicMock(return_value=False) + fake_process_instance.terminate = MagicMock() + fake_process_class.return_value = fake_process_instance + + # Mock get_context to return a context with Queue that returns _FakeQueue + fake_ctx = MagicMock() + fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) + fake_ctx.Process = fake_process_class + + def _mock_get_context(method): + return fake_ctx + + monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) + monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + + +def _setup_ipc_mocks(monkeypatch): + """Helper function to set up IPC function mocks.""" + + # Mock _encode: simple serialization + def _fake_encode(obj, threshold, obj_key, shm_key): + return {obj_key: obj} + + # Mock _load: extract object from result + def _fake_load(result, obj_key, shm_key): + return result.get(obj_key) + + # Mock _set: calculate serialization size + def _fake_set(obj): + return str(obj).encode() + + monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) + + +def _setup_log_mocks(monkeypatch): + """Helper function to set up logging and stats mocks.""" + # Mock OrchestratorMetrics to be a simple class that doesn't require file operations + + class _FakeOrchestratorMetrics: + def __init__(self, num_stages, enable_stats, wall_start_ts): + self.num_stages = num_stages + self.enable_stats = enable_stats + self.stage_first_ts = [None] * num_stages + self.stage_last_ts = [None] * num_stages + self.e2e_done = set() + + def on_stage_metrics(self, stage_id, req_id, metrics): + pass + + def on_finalize_request(self, stage_id, req_id, start_ts): + self.e2e_done.add(req_id) + + def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): + pass + + def build_and_log_summary(self, final_stage_id): + return "Fake summary" + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.OrchestratorMetrics", + _FakeOrchestratorMetrics, + raising=False, + ) + + +@pytest.fixture(autouse=True) +def mock_get_config(monkeypatch): + """Auto-mock get_config and related model loading functions to avoid model path validation.""" + # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports + # This prevents ImportError when async_omni is imported (which happens via omni_stage) + import sys + + fake_tokenizer = MagicMock() + fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + fake_tokenizer.decode = 
MagicMock(return_value="test") + + # Mock init_tokenizer_from_configs (used in async_omni) + def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): + return fake_tokenizer + + # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer) + # This works if the module hasn't been imported yet + monkeypatch.setattr( + "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 2: If the module is already in sys.modules, patch it directly + tokenizer_module_path = "vllm.transformers_utils.tokenizer" + if tokenizer_module_path in sys.modules: + tokenizer_module = sys.modules[tokenizer_module_path] + setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni + + # This is because async_omni imports processor.py, which imports this function at module level + # Mock length_from_prompt_token_ids_or_embeds (used in processor.py) + def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): + # Return a reasonable default length + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + return len(prompt_token_ids) + elif hasattr(prompt_token_ids, "shape"): + return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1 + if prompt_embeds is not None: + if hasattr(prompt_embeds, "shape"): + return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1 + return 10 # Default length + + # Mock in vllm.utils + monkeypatch.setattr( + "vllm.utils.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # Also mock in processor module if it's imported + monkeypatch.setattr( + "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # If processor module is already imported, patch it directly + processor_module_path = "vllm_omni.engine.input_processor" + if processor_module_path in sys.modules: + processor_module = sys.modules[processor_module_path] + setattr( + processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds + ) + + # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked + # This prevents ImportError when async_omni imports processor.py + monkeypatch.setattr( + "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 4: If async_omni is already imported, patch it directly + async_omni_path = "vllm_omni.entrypoints.async_omni" + if async_omni_path in sys.modules: + async_omni_module = sys.modules[async_omni_path] + setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # Now mock get_config and other functions + fake_hf_config = MagicMock() + fake_hf_config.model_type = "qwen2_5_omni" + + def _mock_get_config(model, **kwargs): + return fake_hf_config + + monkeypatch.setattr( + "vllm.transformers_utils.config.get_config", + _mock_get_config, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.get_config", + _mock_get_config, + raising=False, + ) + + # Mock transformers' cached_file to avoid downloading model configs + def _mock_cached_file(path_or_repo_id, *args, **kwargs): + import os + import tempfile + + fake_config_file = 
os.path.join(tempfile.gettempdir(), "fake_config.json") + if not os.path.exists(fake_config_file): + with open(fake_config_file, "w") as f: + f.write('{"model_type": "qwen2_5_omni"}') + return fake_config_file + + monkeypatch.setattr( + "transformers.utils.hub.cached_file", + _mock_cached_file, + raising=False, + ) + monkeypatch.setattr( + "transformers.utils.hub.cached_files", + lambda path_or_repo_id, filenames, **kwargs: ( + [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None + ), + raising=False, + ) + + +def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): + """Test that stage configs are auto-loaded when stage_configs_path is None.""" + + def _fake_loader(model: str, base_engine_args=None): + return [ + _FakeStageConfig(fake_stage_config), + _FakeStageConfig(fake_stage_config), + ] + + # Remove modules from cache BEFORE setting mocks + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + # Set up mocks + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + # Mock load_stage_configs_from_model + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Replace OmniStage + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + # Import the module after mocks are set + import vllm_omni.entrypoints.omni as omni_module + + # Patch the imported function and class in the module + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + # Verify: auto-loaded stage_configs and stage_list have consistent count + assert isinstance(omni.stage_configs, list) + assert len(omni.stage_configs) == 2 + assert len(omni.stage_list) == 2 + # Verify: each Stage is _FakeStage instance + for st in omni.stage_list: + assert isinstance(st, _FakeStage) + # Verify: queues are attached + for st in omni.stage_list: + assert st._in_q is not None + assert st._out_q is not None + # Verify: all stages are ready + assert len(omni._stages_ready) == 2 + + +def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): + """Test that generate raises ValueError when sampling_params_list length doesn't match.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, 
+ ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + with pytest.raises(ValueError): + omni.generate(prompts=["hi"], sampling_params_list=[]) + + +def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): + """Test multi-stage generation pipeline with queue polling.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["processed_input"] = ["processed-for-stage-1"] + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: manually put results into output queues + # Note: We put results before calling generate, which simulates worker processes + # that have already completed. The polling loop will collect them in stage order. + # Stage 0 output (will be collected first) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Stage 1 output (will be collected after stage 0 forwards to it) + # Note: In real flow, stage 1 result would appear after stage 0 forwards, + # but for testing we pre-populate it. The polling loop processes stages + # in order, so stage 0 result will be collected first, then forwarded, + # then stage 1 result will be collected. 
+ omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1, "text": "s1"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + sampling_params_list = [ + OmniDiffusionSamplingParams(num_inference_steps=1), + OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), + ] + prompts = ["hi"] + outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) + + # Both stages have final_output=True, so should aggregate two OmniRequestOutput + assert len(outputs) == 2 + # Verify stage outputs are set + assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] + assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] + # Verify stage 0 input queue received the task + assert not omni.stage_list[0]._in_q.empty() + # Verify stage 1 received forwarded task (process_engine_inputs was called) + assert omni.stage_list[1].process_engine_inputs([], []) is not None + + +def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config): + """Test single-stage generation pipeline with multiple inputs in one batch.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: manually put results into output queues + # Note: We put results before calling generate, which simulates worker processes + # that have already completed. The polling loop will collect them in stage order. 
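+    # Two results are queued per stage, matching the two prompts submitted in this batch.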
+ omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + outputs = omni.generate( + prompts=[ + { + "prompt": "hi", + "negative_prompt": "hi", + "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, + }, + { + "prompt": "hi", + "negative_prompt": "hi", + "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, + }, + ], + sampling_params_list=[ + OmniDiffusionSamplingParams(num_inference_steps=1), + OmniDiffusionSamplingParams(num_inference_steps=1), + ], + ) + + assert len(outputs) == 2 + + +def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): + """Test that generate returns empty list when all stages have final_output=False.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + outputs = omni.generate( + prompts=["p"], + sampling_params_list=[ + OmniDiffusionSamplingParams(num_inference_steps=1), + 
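+            # One sampling-params entry per stage; the fake loader above configures two stages.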
OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), + ], + ) + assert outputs == [] + + +def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config): + """Test that generate uses default sampling params when sampling_params_list is None.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Use the default sampling params + omni.generate(prompts=["p"], sampling_params_list=None) + + +def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): + """Test that _wait_for_stages_ready handles timeout correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Create a stage that doesn't send stage_ready message + class _FakeStageNoReady(_FakeStage): + def init_stage_worker(self, *args, **kwargs): + # Don't send stage_ready message + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + + monkeypatch.setattr( + 
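+        # Register the no-ready stage class so stage initialization times out in this test.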
"vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + # Use very short timeout + omni = Omni(model="any", init_timeout=0.01) + # Verify that no stages are ready + assert len(omni._stages_ready) == 0 + + +def test_generate_handles_error_messages(monkeypatch, fake_stage_config): + """Test that generate handles error messages from stages correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Put error message in output queue + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "error": "test error", + } + ) + # Also put a valid result after error to allow the loop to complete + # (error handling continues the loop, so we need a valid result to finish) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "result"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Generate should handle error gracefully (log but continue) + sampling_params_list = [OmniDiffusionSamplingParams(num_inference_steps=1)] + outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) + # Should return final output (error was logged but didn't stop processing) + assert isinstance(outputs, list) + # Since final_output=True, should have one output + assert len(outputs) == 1 + + +def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): + """Test that close() sends shutdown signal to all input queues.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + 
_setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Call close + omni.close() + + # Verify shutdown signal (None) was sent to input queue + # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) + try: + shutdown_signal = omni.stage_list[0]._in_q.get_nowait() + assert shutdown_signal == SHUTDOWN_TASK + except Empty: + # If queue was already empty or only had stage_ready, that's also acceptable + # The important thing is that close() was called without error + pass + + # Verify stop_stage_worker was called (process should be set) + assert omni.stage_list[0]._proc is not None diff --git a/tests/entrypoints/test_omni_input_preprocessor.py b/tests/entrypoints/test_omni_input_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..77c84f06b3c695ed6bfc6392f6bb7892fdab7fa4 --- /dev/null +++ b/tests/entrypoints/test_omni_input_preprocessor.py @@ -0,0 +1,59 @@ +from vllm_omni.inputs.preprocess import OmniInputPreprocessor + + +def _make_preprocessor(monkeypatch): + preprocessor = object.__new__(OmniInputPreprocessor) + monkeypatch.setattr(preprocessor, "_truncate_inputs", lambda tokens, tokenization_kwargs=None: tokens) + monkeypatch.setattr( + preprocessor, + "_process_multimodal", + lambda *args, **kwargs: {"prompt_token_ids": [1, 2, 3]}, + ) + monkeypatch.setattr(preprocessor, "_tokenize_prompt", lambda prompt_text, tokenization_kwargs=None: [9, 8, 7]) + return preprocessor + + +def test_process_tokens_keeps_additional_information(monkeypatch): + preprocessor = _make_preprocessor(monkeypatch) + parsed = { + "prompt_token_ids": [1, 2, 3], + "prompt_embeds": "embeds", + "additional_information": {"task": ["tts"], "lang": ["auto"]}, + } + + inputs = OmniInputPreprocessor._process_tokens(preprocessor, parsed) + + assert inputs["prompt_token_ids"] == [1, 2, 3] + assert inputs["prompt_embeds"] == "embeds" + assert inputs["additional_information"] == {"task": ["tts"], "lang": ["auto"]} + + +def test_process_text_keeps_additional_information(monkeypatch): + preprocessor = _make_preprocessor(monkeypatch) + parsed = { + "prompt": "hello", + "prompt_embeds": "embeds", + "additional_information": {"speaker": ["alice"]}, + } + + inputs = OmniInputPreprocessor._process_text(preprocessor, parsed) + + assert inputs["prompt_token_ids"] == [9, 8, 7] + assert inputs["prompt_embeds"] == "embeds" + assert inputs["additional_information"] == {"speaker": ["alice"]} + + +def test_process_text_multimodal_skips_empty_payloads(monkeypatch): + preprocessor = _make_preprocessor(monkeypatch) + parsed = { + "prompt": "hello", + "multi_modal_data": {"image": "fake"}, + "prompt_embeds": None, + "additional_information": None, + } + + inputs = OmniInputPreprocessor._process_text(preprocessor, parsed) + + assert inputs["prompt_token_ids"] == [1, 2, 3] + assert "prompt_embeds" not in inputs + assert "additional_information" not in 
inputs diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..f99c6d8336ca7fe8b5989c835159fee92ee231cb --- /dev/null +++ b/tests/entrypoints/test_omni_llm.py @@ -0,0 +1,997 @@ +import uuid +import warnings +from queue import Empty, Queue +from typing import Any +from unittest.mock import MagicMock + +import pytest +from vllm import SamplingParams + +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK + +# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + + +class _FakeEngineArgs(dict): + """Fake engine args that can be used both as object attributes and as **kwargs.""" + + def __init__(self, args_dict: dict[str, Any]): + super().__init__(args_dict) + # Add required attributes if not present + if "model_stage" not in self: + self["model_stage"] = None + if "engine_output_type" not in self: + self["engine_output_type"] = None + # Also set as attributes for object-style access + for key, value in self.items(): + setattr(self, key, value) + + +class _FakeStageConfig: + """Fake stage config object that mimics the real stage config structure.""" + + def __init__(self, config_dict: dict[str, Any]): + # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs) + engine_args_dict = config_dict.get("engine_args", {}) + self.engine_args = _FakeEngineArgs(engine_args_dict) + self.final_output = config_dict.get("final_output", False) + self.final_output_type = config_dict.get("final_output_type", None) + self.stage_id = config_dict.get("stage_id", 0) + # Store original dict for reference + self._config_dict = config_dict + + +class _FakeQueue: + """Fake queue using standard library Queue to replace mp.Queue.""" + + def __init__(self, maxsize=0): + self._queue = Queue(maxsize=maxsize) + + def put(self, item): + self._queue.put(item) + + def put_nowait(self, item): + self._queue.put_nowait(item) + + def get(self): + return self._queue.get() + + def get_nowait(self): + return self._queue.get_nowait() + + def empty(self): + return self._queue.empty() + + +class _FakeStage: + """Lightweight Stage stub for multi-process pipeline version with queue support.""" + + def __init__(self, config, stage_init_timeout: int = 300): + # Handle both dict and object configs + if isinstance(config, dict): + config = _FakeStageConfig(config) + self.config = config + self.stage_config = config + self.engine = None + self.engine_outputs = None + # Set attributes that OmniStage expects + self.stage_id = getattr(config, "stage_id", 0) + self.engine_args = config.engine_args + self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "llm" + # set default sampling params + self.default_sampling_params = SamplingParams(temperature=1.0) + # Allow configuring final_output and final_output_type + self.final_output = config.final_output if hasattr(config, "final_output") else False + self.final_output_type = getattr(config, "final_output_type", None) + # Configurable processing logic, default returns placeholder + processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) + self._processed_input = processed_input + # Queue references (set by attach_queues) + self._in_q = None + self._out_q = None + self._proc = None # Mock process reference + 
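+        # Coerce the timeout to a non-negative integer; the fake worker never waits on it.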
self._stage_init_timeout = max(0, int(stage_init_timeout)) + + def attach_queues(self, in_q, out_q): + """Attach input and output queues.""" + self._in_q = in_q + self._out_q = out_q + + def init_stage_worker( + self, + model: str, + *, + is_async: bool = False, + shm_threshold_bytes: int = 65536, + ctx=None, + batch_timeout: int = 10, + **kwargs, + ): + """Mock init_stage_worker: don't start real process, just send stage_ready message.""" + # Create a mock process object + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + # Send stage_ready message to output queue + if self._out_q is not None: + try: + self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) + except Exception: + pass + + def stop_stage_worker(self): + """Mock stop_stage_worker: clean up queue references.""" + if self._in_q is not None: + try: + self._in_q.put_nowait(SHUTDOWN_TASK) + except Exception: + pass + + def submit(self, payload: dict[str, Any]): + """Submit task to input queue.""" + if self._in_q is not None: + self._in_q.put(payload) + + def try_collect(self) -> Any: + """Non-blocking collect from output queue.""" + if self._out_q is None: + return None + try: + return self._out_q.get_nowait() + except Empty: + return None + + def set_engine_outputs(self, outputs): + """Set engine outputs for the stage.""" + self.engine_outputs = outputs + + def process_engine_inputs(self, stage_list, prompts): + """Process engine inputs: return preset processed result.""" + return self._processed_input + + +class _FakeEngine: + """Lightweight Engine stub: provides generate iterator output.""" + + def __init__(self, outputs: list[Any]): + self._outputs = outputs + + def generate(self, prompts, sampling_params): + # Record the most recent prompts for outer assertions + self._last_prompts = prompts + # Simplified: return preset list at once, ensuring iterability + yield from self._outputs + + +@pytest.fixture +def fake_stage_config(): + return { + # Don't include 'model' in engine_args since it's passed separately + "engine_args": {}, + "final_output": True, + "final_output_type": "text", + # Second stage will use processed_input to verify the chain + "processed_input": ["processed-by-stage"], + } + + +def _setup_engine_mocks(monkeypatch): + """Helper function to set up common engine mocks.""" + fake_engine = MagicMock() + # Add necessary attributes to fake_engine + fake_engine.tokenizer = MagicMock() + fake_engine.log_stats = False + fake_engine.vllm_config = MagicMock() + fake_engine.vllm_config.model_config = MagicMock() + fake_engine.vllm_config.model_config.io_processor_plugin = None + fake_engine.get_supported_tasks = MagicMock(return_value=[]) + fake_engine.model_config = MagicMock() + fake_engine.model_config.io_processor_plugin = None + # Add registry with resolve_model_cls method + fake_registry = MagicMock() + fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch")) + fake_engine.model_config.registry = fake_registry + fake_engine.vllm_config.model_config.registry = fake_registry + + monkeypatch.setattr( + "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", + lambda **kw: fake_engine, + raising=False, + ) + + # Mock model_config.registry.resolve_model_cls to return a tuple + # Use a real class instead of MagicMock to avoid inspect.getsource issues + class FakeModelClass: + pass + + monkeypatch.setattr( + 
"vllm.model_executor.model_loader.utils.get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils._get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + # Mock try_create_mm_pooling_model_cls to return the class as-is + monkeypatch.setattr( + "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", + lambda model_cls: model_cls, + raising=False, + ) + + # Mock _enable_processor_cache to return False + monkeypatch.setattr( + "vllm.multimodal.cache._enable_processor_cache", + lambda model_config, mm_registry: False, + raising=False, + ) + + # Mock get_io_processor to return None + monkeypatch.setattr( + "vllm.plugins.io_processors.get_io_processor", + lambda vllm_config, io_processor_plugin: None, + raising=False, + ) + + +def _setup_multiprocessing_mocks(monkeypatch): + """Helper function to set up multiprocessing mocks.""" + import multiprocessing as mp + + # Mock Process + fake_process_class = MagicMock() + fake_process_instance = MagicMock() + fake_process_instance.start = MagicMock() + fake_process_instance.join = MagicMock() + fake_process_instance.is_alive = MagicMock(return_value=False) + fake_process_instance.terminate = MagicMock() + fake_process_class.return_value = fake_process_instance + + # Mock get_context to return a context with Queue that returns _FakeQueue + fake_ctx = MagicMock() + fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) + fake_ctx.Process = fake_process_class + + def _mock_get_context(method): + return fake_ctx + + monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) + monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + + +def _setup_ipc_mocks(monkeypatch): + """Helper function to set up IPC function mocks.""" + + # Mock _encode: simple serialization + def _fake_encode(obj, threshold, obj_key, shm_key): + return {obj_key: obj} + + # Mock _load: extract object from result + def _fake_load(result, obj_key, shm_key): + return result.get(obj_key) + + # Mock _set: calculate serialization size + def _fake_set(obj): + return str(obj).encode() + + monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) + + +def _setup_log_mocks(monkeypatch): + """Helper function to set up logging and stats mocks.""" + # Mock OrchestratorMetrics to be a simple class that doesn't require file operations + + class _FakeOrchestratorMetrics: + def __init__(self, num_stages, enable_stats, wall_start_ts): + self.num_stages = num_stages + self.enable_stats = enable_stats + self.stage_first_ts = [None] * num_stages + self.stage_last_ts = [None] * num_stages + self.e2e_done = set() + + def on_stage_metrics(self, stage_id, req_id, metrics): + pass + + def on_finalize_request(self, stage_id, req_id, start_ts): + self.e2e_done.add(req_id) + + def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): + pass + + def build_and_log_summary(self, final_stage_id): + return "Fake summary" + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.OrchestratorMetrics", + _FakeOrchestratorMetrics, + raising=False, + ) + + +@pytest.fixture(autouse=True) +def mock_get_config(monkeypatch): + """Auto-mock get_config and related model loading functions to avoid model path 
validation.""" + # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports + # This prevents ImportError when async_omni is imported (which happens via omni_stage) + import sys + + fake_tokenizer = MagicMock() + fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + fake_tokenizer.decode = MagicMock(return_value="test") + + # Mock init_tokenizer_from_configs (used in async_omni) + def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): + return fake_tokenizer + + # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer) + # This works if the module hasn't been imported yet + monkeypatch.setattr( + "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 2: If the module is already in sys.modules, patch it directly + tokenizer_module_path = "vllm.transformers_utils.tokenizer" + if tokenizer_module_path in sys.modules: + tokenizer_module = sys.modules[tokenizer_module_path] + setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni + + # This is because async_omni imports processor.py, which imports this function at module level + # Mock length_from_prompt_token_ids_or_embeds (used in processor.py) + def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): + # Return a reasonable default length + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + return len(prompt_token_ids) + elif hasattr(prompt_token_ids, "shape"): + return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1 + if prompt_embeds is not None: + if hasattr(prompt_embeds, "shape"): + return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1 + return 10 # Default length + + # Mock in vllm.utils + monkeypatch.setattr( + "vllm.utils.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # Also mock in processor module if it's imported + monkeypatch.setattr( + "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # If processor module is already imported, patch it directly + processor_module_path = "vllm_omni.engine.input_processor" + if processor_module_path in sys.modules: + processor_module = sys.modules[processor_module_path] + setattr( + processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds + ) + + # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked + # This prevents ImportError when async_omni imports processor.py + monkeypatch.setattr( + "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 4: If async_omni is already imported, patch it directly + async_omni_path = "vllm_omni.entrypoints.async_omni" + if async_omni_path in sys.modules: + async_omni_module = sys.modules[async_omni_path] + setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # Now mock get_config and other functions + fake_hf_config = MagicMock() + fake_hf_config.model_type = "qwen2_5_omni" + + def _mock_get_config(model, **kwargs): + return fake_hf_config + + monkeypatch.setattr( + "vllm.transformers_utils.config.get_config", + _mock_get_config, + 
raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.get_config", + _mock_get_config, + raising=False, + ) + + # Mock transformers' cached_file to avoid downloading model configs + def _mock_cached_file(path_or_repo_id, *args, **kwargs): + import os + import tempfile + + fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") + if not os.path.exists(fake_config_file): + with open(fake_config_file, "w") as f: + f.write('{"model_type": "qwen2_5_omni"}') + return fake_config_file + + monkeypatch.setattr( + "transformers.utils.hub.cached_file", + _mock_cached_file, + raising=False, + ) + monkeypatch.setattr( + "transformers.utils.hub.cached_files", + lambda path_or_repo_id, filenames, **kwargs: ( + [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None + ), + raising=False, + ) + + +def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): + """Test that stage configs are auto-loaded when stage_configs_path is None.""" + + def _fake_loader(model: str, base_engine_args=None): + return [ + _FakeStageConfig(fake_stage_config), + _FakeStageConfig(fake_stage_config), + ] + + # Remove modules from cache BEFORE setting mocks + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + # Set up mocks + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + # Mock load_stage_configs_from_model + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Replace OmniStage + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + # Import the module after mocks are set + import vllm_omni.entrypoints.omni as omni_module + + # Patch the imported function and class in the module + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + # Verify: auto-loaded stage_configs and stage_list have consistent count + assert isinstance(omni.stage_configs, list) + assert len(omni.stage_configs) == 2 + assert len(omni.stage_list) == 2 + # Verify: each Stage is _FakeStage instance + for st in omni.stage_list: + assert isinstance(st, _FakeStage) + # Verify: queues are attached + for st in omni.stage_list: + assert st._in_q is not None + assert st._out_q is not None + # Verify: all stages are ready + assert len(omni._stages_ready) == 2 + + +def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): + """Test that generate raises ValueError when sampling_params_list length doesn't match.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + 
"vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + with pytest.raises(ValueError): + omni.generate(prompts=["hi"], sampling_params_list=[]) + + +def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): + """Test multi-stage generation pipeline with queue polling.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["processed_input"] = ["processed-for-stage-1"] + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: manually put results into output queues + # Note: We put results before calling generate, which simulates worker processes + # that have already completed. The polling loop will collect them in stage order. + # Stage 0 output (will be collected first) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Stage 1 output (will be collected after stage 0 forwards to it) + # Note: In real flow, stage 1 result would appear after stage 0 forwards, + # but for testing we pre-populate it. The polling loop processes stages + # in order, so stage 0 result will be collected first, then forwarded, + # then stage 1 result will be collected. 
+ omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1, "text": "s1"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + sampling_params_list = [ + SamplingParams(temperature=0.7), + SamplingParams(temperature=0.8), + ] + prompts = ["hi"] + outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) + + # Both stages have final_output=True, so should aggregate two OmniRequestOutput + assert len(outputs) == 2 + # Verify stage outputs are set + assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] + assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] + # Verify stage 0 input queue received the task + assert not omni.stage_list[0]._in_q.empty() + # Verify stage 1 received forwarded task (process_engine_inputs was called) + assert omni.stage_list[1].process_engine_inputs([], []) is not None + + +def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): + """Test that generate returns empty list when all stages have final_output=False.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + outputs = omni.generate( + prompts=["p"], + sampling_params_list=[ + SamplingParams(temperature=0.7), + SamplingParams(temperature=0.8), + ], + ) + assert outputs == [] + + +def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config): + """Test that generate uses default sampling params when sampling_params_list is None.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = 
dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Use the default sampling params + omni.generate(prompts=["p"], sampling_params_list=None) + + +def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): + """Test that _wait_for_stages_ready handles timeout correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Create a stage that doesn't send stage_ready message + class _FakeStageNoReady(_FakeStage): + def init_stage_worker(self, *args, **kwargs): + # Don't send stage_ready message + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs)) + + from 
vllm_omni.entrypoints.omni import Omni + + # Use very short timeout + omni = Omni(model="any", init_timeout=0.01) + # Verify that no stages are ready + assert len(omni._stages_ready) == 0 + + +def test_generate_handles_error_messages(monkeypatch, fake_stage_config): + """Test that generate handles error messages from stages correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_<uuid>" + expected_request_id = f"0_{test_uuid}" + + # Put error message in output queue + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "error": "test error", + } + ) + # Also put a valid result after error to allow the loop to complete + # (error handling continues the loop, so we need a valid result to finish) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "result"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Generate should handle error gracefully (log but continue) + sampling_params_list = [SamplingParams(temperature=0.7)] + outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) + # Should return final output (error was logged but didn't stop processing) + assert isinstance(outputs, list) + # Since final_output=True, should have one output + assert len(outputs) == 1 + + +def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): + """Test that close() sends shutdown signal to all input queues.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import 
vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Call close + omni.close() + + # Verify shutdown signal (None) was sent to input queue + # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) + try: + shutdown_signal = omni.stage_list[0]._in_q.get_nowait() + assert shutdown_signal == SHUTDOWN_TASK + except Empty: + # If queue was already empty or only had stage_ready, that's also acceptable + # The important thing is that close() was called without error + pass + + # Verify stop_stage_worker was called (process should be set) + assert omni.stage_list[0]._proc is not None diff --git a/tests/entrypoints/test_omni_new_request_data.py b/tests/entrypoints/test_omni_new_request_data.py new file mode 100644 index 0000000000000000000000000000000000000000..776509d5bba597e8152fbb5db79cf71ef51e06a7 --- /dev/null +++ b/tests/entrypoints/test_omni_new_request_data.py @@ -0,0 +1,51 @@ +from types import SimpleNamespace + +import torch + +from vllm_omni.core.sched.output import OmniNewRequestData + + +def test_omni_new_request_data_copies_payloads(): + prompt_embeds = torch.randn(2, 3) + additional_information = { + "speaker": ["test"], + "codes": torch.tensor([1, 2], dtype=torch.int64), + } + request = SimpleNamespace( + request_id="req-1", + external_req_id="ext-1", + prompt_token_ids=[101, 102], + mm_features=None, + sampling_params=None, + pooling_params=None, + num_computed_tokens=0, + lora_request=None, + prompt_embeds=prompt_embeds, + additional_information=additional_information, + ) + + data = OmniNewRequestData.from_request(request, ([0, 1],), prefill_token_ids=[101, 102]) + + assert data.prompt_embeds is prompt_embeds + assert data.additional_information is additional_information + assert data.prefill_token_ids == [101, 102] + + +def test_omni_new_request_data_allows_missing_payloads(): + request = SimpleNamespace( + request_id="req-2", + external_req_id="ext-2", + prompt_token_ids=[201, 202], + mm_features=None, + sampling_params=None, + pooling_params=None, + num_computed_tokens=0, + lora_request=None, + prompt_embeds=None, + additional_information=None, + ) + + data = OmniNewRequestData.from_request(request, ([0],), prefill_token_ids=None) + + assert data.prompt_embeds is None + assert data.additional_information is None diff --git a/tests/entrypoints/test_omni_stage_diffusion_config.py b/tests/entrypoints/test_omni_stage_diffusion_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5fe04cbbd88d361d2c3683daf43f4e462eb92151 --- /dev/null +++ b/tests/entrypoints/test_omni_stage_diffusion_config.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.entrypoints.omni_stage import _build_od_config + + +def test_build_od_config_includes_diffusion_fields(): + engine_args = { + "cache_backend": "cache_dit", + "cache_config": {"Fn_compute_blocks": 2}, + "vae_use_slicing": True, + } + od_config = _build_od_config(engine_args, model="dummy-model") + + assert od_config["model"] == "dummy-model" + assert od_config["cache_backend"] == "cache_dit" + assert od_config["cache_config"]["Fn_compute_blocks"] == 2 + assert od_config["vae_use_slicing"] is True + + +def 
test_build_od_config_respects_explicit_config(): + engine_args = { + "od_config": {"cache_backend": "tea_cache"}, + "cache_backend": "cache_dit", + } + od_config = _build_od_config(engine_args, model="dummy-model") + assert od_config == {"cache_backend": "tea_cache"} diff --git a/tests/entrypoints/test_stage_utils.py b/tests/entrypoints/test_stage_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac503639be627b1e4068b3476876a9065c2f9258 --- /dev/null +++ b/tests/entrypoints/test_stage_utils.py @@ -0,0 +1,108 @@ +import os +import sys +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.entrypoints.stage_utils import set_stage_devices + + +def _make_dummy_torch(call_log): + class _Props: + def __init__(self, total): + self.total_memory = total + + class _Cuda: + @staticmethod + def is_available(): + return True + + @staticmethod + def set_device(idx): + call_log.append(idx) + + @staticmethod + def device_count(): + return 2 + + @staticmethod + def get_device_properties(idx): + return _Props(total=16000) + + @staticmethod + def mem_get_info(idx): + return (8000, 16000) + + @staticmethod + def get_device_name(idx): + return f"gpu-{idx}" + + class _Torch: + cuda = _Cuda + + return _Torch + + +def _make_mock_platform(device_type: str = "cuda", env_var: str = "CUDA_VISIBLE_DEVICES"): + """Create a mock platform for testing.""" + mock_platform = MagicMock() + mock_platform.device_type = device_type + mock_platform.device_control_env_var = env_var + return mock_platform + + +@pytest.mark.usefixtures("clean_gpu_memory_between_tests") +def test_set_stage_devices_respects_logical_ids(monkeypatch): + # Preserve an existing logical mapping and ensure devices "0,1" map through it. + monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "6,7") + call_log: list[int] = [] + dummy_torch = _make_dummy_torch(call_log) + monkeypatch.setitem(sys.modules, "torch", dummy_torch) + + # Mock the platform at the source module where it's defined + mock_platform = _make_mock_platform(device_type="cuda", env_var="CUDA_VISIBLE_DEVICES") + monkeypatch.setattr( + "vllm_omni.platforms.current_omni_platform", + mock_platform, + ) + + set_stage_devices(stage_id=0, devices="0,1") + + assert os.environ["CUDA_VISIBLE_DEVICES"] == "6,7" + + +@pytest.mark.usefixtures("clean_gpu_memory_between_tests") +def test_set_stage_devices_npu_platform(monkeypatch): + """Test that set_stage_devices works correctly for NPU platform.""" + monkeypatch.setenv("ASCEND_RT_VISIBLE_DEVICES", "4,5") + call_log: list[int] = [] + + # Create NPU mock torch + class _Npu: + @staticmethod + def is_available(): + return True + + @staticmethod + def set_device(idx): + call_log.append(idx) + + @staticmethod + def device_count(): + return 2 + + class _NpuTorch: + npu = _Npu + + monkeypatch.setitem(sys.modules, "torch", _NpuTorch) + + # Mock NPU platform at the source module where it's defined + mock_platform = _make_mock_platform(device_type="npu", env_var="ASCEND_RT_VISIBLE_DEVICES") + monkeypatch.setattr( + "vllm_omni.platforms.current_omni_platform", + mock_platform, + ) + + set_stage_devices(stage_id=0, devices="0,1") + + assert os.environ["ASCEND_RT_VISIBLE_DEVICES"] == "4,5" diff --git a/tests/model_executor/models/qwen2_5_omni/test_audio_length.py b/tests/model_executor/models/qwen2_5_omni/test_audio_length.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5f098172c670cfa712a305b3fef4dfb01c974b --- /dev/null +++ b/tests/model_executor/models/qwen2_5_omni/test_audio_length.py @@ -0,0 +1,53 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + + +def test_resolve_max_mel_frames_default(): + from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import resolve_max_mel_frames + + assert resolve_max_mel_frames(None, default=30000) == 30000 + assert resolve_max_mel_frames(None, default=6000) == 6000 + + +def test_resolve_max_mel_frames_explicit(): + from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import resolve_max_mel_frames + + # Explicit argument always wins over default + assert resolve_max_mel_frames(123, default=30000) == 123 + assert resolve_max_mel_frames(6000, default=30000) == 6000 + assert resolve_max_mel_frames(0, default=30000) == 0 + + +@pytest.mark.parametrize("repeats", [2, 4]) +@pytest.mark.parametrize("code_len", [0, 1, 32768]) +@pytest.mark.parametrize("max_mel_frames", [None, -1, 0, 1, 6000, 30000]) +def test_cap_and_align_mel_length_no_mismatch(repeats, code_len, max_mel_frames): + """Guard that any max_mel_frames yields a mel length aligned to repeats, and + consistent with the truncated code length (prevents concat mismatch). + """ + from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import cap_and_align_mel_length + + target_code_len, target_mel_len = cap_and_align_mel_length( + code_len=code_len, + repeats=repeats, + max_mel_frames=max_mel_frames, + ) + + assert isinstance(target_code_len, int) + assert isinstance(target_mel_len, int) + + if code_len == 0: + assert target_code_len == 0 + assert target_mel_len == 0 + return + + assert target_code_len >= 1 + assert target_mel_len >= repeats + assert target_mel_len % repeats == 0 + assert target_mel_len == target_code_len * repeats + assert target_code_len <= code_len + + if max_mel_frames is not None and int(max_mel_frames) > 0 and int(max_mel_frames) >= repeats: + assert target_mel_len <= int(max_mel_frames) diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebc0e43cb60c58e6678a1b0304ec015e56d8dd8 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for OmniRequestOutput class.""" + +from unittest.mock import Mock + +from PIL import Image + +from vllm_omni.outputs import OmniRequestOutput + + +class TestOmniRequestOutput: + """Tests for OmniRequestOutput class.""" + + def test_from_diffusion(self): + """Test creating output from diffusion model.""" + images = [Image.new("RGB", (64, 64), color="red")] + output = OmniRequestOutput.from_diffusion( + request_id="test-123", + images=images, + prompt="a cat", + metrics={"steps": 50}, + ) + assert output.request_id == "test-123" + assert output.images == images + assert output.prompt == "a cat" + assert output.metrics == {"steps": 50} + assert output.is_diffusion_output + assert output.num_images == 1 + + def test_from_pipeline(self): + """Test creating output from pipeline stage.""" + mock_request_output = Mock() + mock_request_output.request_id = "pipeline-123" + mock_request_output.prompt_token_ids = [1, 2, 3] + mock_request_output.outputs = [Mock()] + mock_request_output.encoder_prompt_token_ids = None + mock_request_output.prompt_logprobs = None + mock_request_output.num_cached_tokens = 10 + mock_request_output.kv_transfer_params = None + mock_request_output.multimodal_output = {"image": Mock()} + + output = OmniRequestOutput.from_pipeline( + stage_id=0, 
+ final_output_type="text", + request_output=mock_request_output, + ) + + assert output.request_id == "pipeline-123" + assert output.stage_id == 0 + assert output.final_output_type == "text" + assert output.is_pipeline_output + + def test_prompt_token_ids_property(self): + """Test prompt_token_ids property for streaming compatibility.""" + mock_request_output = Mock() + mock_request_output.prompt_token_ids = [1, 2, 3, 4, 5] + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="text", + request_output=mock_request_output, + ) + + assert output.prompt_token_ids == [1, 2, 3, 4, 5] + + def test_prompt_token_ids_none_when_no_request_output(self): + """Test prompt_token_ids returns None when no request_output.""" + output = OmniRequestOutput.from_diffusion( + request_id="test-123", + images=[], + prompt="a cat", + ) + assert output.prompt_token_ids is None + + def test_outputs_property(self): + """Test outputs property for chat completion compatibility.""" + mock_output = Mock() + mock_request_output = Mock() + mock_request_output.outputs = [mock_output] + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="text", + request_output=mock_request_output, + ) + + assert output.outputs == [mock_output] + + def test_outputs_empty_when_no_request_output(self): + """Test outputs returns empty list when no request_output.""" + output = OmniRequestOutput.from_diffusion( + request_id="test-123", + images=[], + prompt="a cat", + ) + assert output.outputs == [] + + def test_encoder_prompt_token_ids_property(self): + """Test encoder_prompt_token_ids property.""" + mock_request_output = Mock() + mock_request_output.encoder_prompt_token_ids = [10, 20, 30] + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="text", + request_output=mock_request_output, + ) + + assert output.encoder_prompt_token_ids == [10, 20, 30] + + def test_num_cached_tokens_property(self): + """Test num_cached_tokens property.""" + mock_request_output = Mock() + mock_request_output.num_cached_tokens = 42 + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="text", + request_output=mock_request_output, + ) + + assert output.num_cached_tokens == 42 + + def test_multimodal_output_property(self): + """Test multimodal_output property.""" + mock_request_output = Mock() + mock_audio = Mock() + expected_output = {"audio": mock_audio} + mock_request_output.multimodal_output = expected_output + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="audio", + request_output=mock_request_output, + ) + + assert output.multimodal_output is expected_output + + def test_multimodal_output_prefers_completion_output(self): + """Test multimodal_output prefers completion output payloads.""" + completion_output = Mock() + completion_mm = {"audio": Mock()} + completion_output.multimodal_output = completion_mm + + mock_request_output = Mock() + mock_request_output.outputs = [completion_output] + mock_request_output.multimodal_output = {"audio": Mock()} + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="audio", + request_output=mock_request_output, + ) + + assert output.multimodal_output is completion_mm + + def test_to_dict_diffusion(self): + """Test to_dict for diffusion output.""" + output = OmniRequestOutput.from_diffusion( + request_id="test-123", + images=[Image.new("RGB", (64, 64), color="red")], + prompt="a cat", + metrics={"steps": 50}, + ) + result = output.to_dict() + + assert 
result["request_id"] == "test-123" + assert result["finished"] is True + assert result["final_output_type"] == "image" + assert result["num_images"] == 1 + assert result["prompt"] == "a cat" + + def test_to_dict_pipeline(self): + """Test to_dict for pipeline output.""" + mock_request_output = Mock() + mock_request_output.request_id = "pipeline-123" + + output = OmniRequestOutput.from_pipeline( + stage_id=0, + final_output_type="text", + request_output=mock_request_output, + ) + result = output.to_dict() + + assert result["request_id"] == "pipeline-123" + assert result["finished"] is True + assert result["final_output_type"] == "text" + assert result["stage_id"] == 0 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c513a4d360b9b8d35d18b7b23d5e5cf1c05052 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Some functions are copied from vllm/tests/utils.py +import functools +import os +import signal +import subprocess +import sys +import tempfile +import threading +import time +from collections.abc import Callable +from contextlib import ExitStack, contextmanager, suppress +from typing import Any, Literal + +import cloudpickle +import pytest +import torch +from typing_extensions import ParamSpec +from vllm.platforms import current_platform +from vllm.utils.torch_utils import cuda_device_count_stateless + +_P = ParamSpec("_P") + +if current_platform.is_rocm(): + from amdsmi import ( + amdsmi_get_gpu_vram_usage, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + ) + + @contextmanager + def _nvml(): + try: + amdsmi_init() + yield + finally: + amdsmi_shut_down() +elif current_platform.is_cuda(): + from vllm.third_party.pynvml import ( + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, + ) + + @contextmanager + def _nvml(): + try: + nvmlInit() + yield + finally: + nvmlShutdown() +else: + + @contextmanager + def _nvml(): + yield + + +def get_physical_device_indices(devices): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible_devices is None: + return devices + + visible_indices = [int(x) for x in visible_devices.split(",")] + index_mapping = {i: physical for i, physical in enumerate(visible_indices)} + return [index_mapping[i] for i in devices if i in index_mapping] + + +@_nvml() +def wait_for_gpu_memory_to_clear( + *, + devices: list[int], + threshold_bytes: int | None = None, + threshold_ratio: float | None = None, + timeout_s: float = 120, +) -> None: + import gc + + assert threshold_bytes is not None or threshold_ratio is not None + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. 
+ devices = get_physical_device_indices(devices) + start_time = time.time() + + # Print waiting start information + device_list = ", ".join(str(d) for d in devices) + if threshold_bytes is not None: + threshold_str = f"{threshold_bytes / 2**30:.2f} GiB" + condition_str = f"Memory usage ≤ {threshold_str}" + else: + threshold_percent = threshold_ratio * 100 + threshold_str = f"{threshold_percent:.1f}%" + condition_str = f"Memory usage ratio ≤ {threshold_str}" + + print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}") + + # Define the is_free function based on threshold type + if threshold_bytes is not None: + + def is_free(used, total): + return used <= threshold_bytes / 2**30 + else: + + def is_free(used, total): + return used / total <= threshold_ratio + + while True: + output: dict[int, str] = {} + output_raw: dict[int, tuple[float, float]] = {} + for device in devices: + if current_platform.is_rocm(): + dev_handle = amdsmi_get_processor_handles()[device] + mem_info = amdsmi_get_gpu_vram_usage(dev_handle) + gb_used = mem_info["vram_used"] / 2**10 + gb_total = mem_info["vram_total"] / 2**10 + else: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + gb_total = mem_info.total / 2**30 + output_raw[device] = (gb_used, gb_total) + # Format to more readable form + usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0 + output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)" + + # Optimized GPU memory status print + print("[GPU Memory Status] Current usage:") + for device_id, mem_info in output.items(): + print(f" GPU {device_id}: {mem_info}") + + # Calculate waiting duration + dur_s = time.time() - start_time + elapsed_minutes = dur_s / 60 + + # Check if all devices meet the condition + if all(is_free(used, total) for used, total in output_raw.values()): + # Optimized completion message + print(f"[GPU Memory Freed] Devices {device_list} meet memory condition") + print(f" Condition: {condition_str}") + print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") + print(" Final status:") + for device_id, mem_info in output.items(): + print(f" GPU {device_id}: {mem_info}") + break + + # Check timeout + if dur_s >= timeout_s: + raise ValueError( + f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n" + f"Condition: {condition_str}\n" + f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices) + ) + + # Add waiting hint (optional) + if dur_s > 10 and int(dur_s) % 10 == 0: # Show hint every 10 seconds + print(f"Waiting... Already waited {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)") + + gc.collect() + torch.cuda.empty_cache() + + time.sleep(5) + + +def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to fork a new process for each test function. + See https://github.com/vllm-project/vllm/issues/7053 for more details. + """ + + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Make the process the leader of its own process group + # to avoid sending SIGTERM to the parent process + os.setpgrp() + from _pytest.outcomes import Skipped + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. 
+ with ( + tempfile.NamedTemporaryFile( + delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc" + ) as exc_file, + ExitStack() as delete_after, + ): + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. + delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {"pickled_exception": e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. + exc_to_serialize = { + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": tb_string, + } + try: + with open(exc_file_path, "wb") as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. + print(tb_string) + os._exit(1) + else: + os._exit(0) + else: + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with suppress(Exception), open(exc_file_path, "rb") as f: + exc_info = cloudpickle.load(f) + + if (original_exception := exc_info.get("pickled_exception")) is not None: + # Re-raise the actual exception object if it was + # successfully pickled. 
+ assert isinstance(original_exception, Exception) + raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})" + ) from None + + return wrapper + + +def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function.""" + + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + # Check if we're already in a subprocess + if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": + # If we are, just run the function directly + return f(*args, **kwargs) + + import torch.multiprocessing as mp + + with suppress(RuntimeError): + mp.set_start_method("spawn") + + # Get the module + module_name = f.__module__ + + # Create a process with environment variable set + env = os.environ.copy() + env["RUNNING_IN_SUBPROCESS"] = "1" + + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "new_process.tmp") + + # `cloudpickle` allows pickling complex functions directly + input_bytes = cloudpickle.dumps((f, output_filepath)) + + cmd = [sys.executable, "-m", f"{module_name}"] + + returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e + + return wrapper + + +def create_new_process_for_each_test( + method: Literal["spawn", "fork"] | None = None, +) -> Callable[[Callable[_P, None]], Callable[_P, None]]: + """Creates a decorator that runs each test function in a new process. + + Args: + method: The process creation method. Can be either "spawn" or "fork". + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. + + Returns: + A decorator to run test functions in separate processes. + """ + if method is None: + # TODO: Spawn is not working correctly on ROCm + # The test content will not run and tests passed immediately. + # For now, using `fork` for ROCm as it can run with `fork` + # and tests are running correctly. + use_spawn = current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" + + assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" + + if method == "fork": + return fork_new_process_for_each_test + + return spawn_new_process_for_each_test + + +def cuda_marks(*, res: str, num_cards: int): + """ + Get a collection of pytest marks to apply for `@cuda_test`. + + Args: + res: Resource type, e.g., "L4" or "H100". + num_cards: Number of GPU cards required. + + Returns: + List of pytest marks to apply. + """ + test_platform_detail = pytest.mark.cuda + + if res == "L4": + test_resource = pytest.mark.L4 + elif res == "H100": + test_resource = pytest.mark.H100 + else: + raise ValueError(f"Invalid CUDA resource type: {res}. 
Supported: L4, H100") + + marks = [test_resource, test_platform_detail] + + if num_cards == 1: + return marks + else: + test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards) + test_skipif = pytest.mark.skipif_cuda( + cuda_device_count_stateless() < num_cards, + reason=f"Need at least {num_cards} CUDA GPUs to run the test.", + ) + return marks + [test_distributed, test_skipif] + + +def rocm_marks(*, res: str, num_cards: int): + """ + Get a collection of pytest marks to apply for `@rocm_test`. + + Args: + res: Resource type, e.g., "MI325". + num_cards: Number of GPU cards required. + + Returns: + List of pytest marks to apply. + """ + test_platform_detail = pytest.mark.rocm + + if res == "MI325": + test_resource = pytest.mark.MI325 + else: + raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325") + + marks = [test_resource, test_platform_detail] + + if num_cards == 1: + return marks + else: + test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards) + # TODO: add ROCm support for `skipif_rocm` marker + return marks + [test_distributed] + + +def gpu_marks(*, res: str, num_cards: int): + """ + Get a collection of pytest marks to apply for `@gpu_test`. + Platform is automatically determined based on resource type. + + Args: + res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm. + num_cards: Number of GPU cards required. + + Returns: + List of pytest marks to apply. + """ + test_platform = pytest.mark.gpu + if res in ("L4", "H100"): + return [test_platform] + cuda_marks(res=res, num_cards=num_cards) + if res == "MI325": + return [test_platform] + rocm_marks(res=res, num_cards=num_cards) + raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325") + + +def npu_marks(*, res: str, num_cards: int): + """Get a collection of pytest marks to apply for `@npu_test`.""" + test_platform = pytest.mark.npu + if res == "A2": + test_resource = pytest.mark.A2 + elif res == "A3": + test_resource = pytest.mark.A3 + else: + # TODO: Currently we don't have various NPU card types defined + # Use None to skip resource-specific marking for unknown types + test_resource = None + + if num_cards == 1: + return [mark for mark in [test_platform, test_resource] if mark is not None] + else: + # Multiple cards scenario needs distributed_npu mark + test_distributed = pytest.mark.distributed_npu(num_cards=num_cards) + # TODO: add NPU support for `skipif_npu` marker + return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None] + + +def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1): + """ + Decorate a test for multiple hardware platforms with a single call. + Automatically wraps the test with @create_new_process_for_each_test() for distributed tests. + + Args: + res: Mapping from platform to resource type. Supported platforms/resources: + - cuda: L4, H100 + - rocm: MI325 + - npu: A2, A3 + num_cards: Number of cards required. Can be: + - int: same card count for all platforms (default: 1) + - dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2} + + Example: + @hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards={"cuda": 2, "rocm": 2, "npu": 2}, + ) + def test_multi_platform(): + ... 
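+
+        # Single-platform, single-card variant (illustrative). num_cards
+        # defaults to 1, so the test is not wrapped in a new process:
+        @hardware_test(res={"cuda": "L4"})
+        def test_cuda_only():
+            ...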
+ """ + # Validate platforms + # Don't validate platform details in this decorator + for platform, _ in res.items(): + if platform not in ("cuda", "rocm", "npu"): + raise ValueError(f"Unsupported platform: {platform}") + + # Normalize num_cards + if isinstance(num_cards, int): + num_cards_dict = {platform: num_cards for platform in res.keys()} + else: + num_cards_dict = num_cards + for platform in num_cards_dict.keys(): + if platform not in res: + raise ValueError( + f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}" + ) + for platform in res.keys(): + if platform not in num_cards_dict: + num_cards_dict[platform] = 1 + + # Collect marks from all platforms + all_marks: list[Callable[[Callable[_P, None]], Callable[_P, None]]] = [] + for platform, resource in res.items(): + cards = num_cards_dict[platform] + if platform == "cuda" or platform == "rocm": + marks = gpu_marks(res=resource, num_cards=cards) + elif platform == "npu": + marks = npu_marks(res=resource, num_cards=cards) + else: + raise ValueError(f"Unsupported platform: {platform}") + all_marks.extend(marks) + + create_new_process_flag = False + for cards in num_cards_dict.values(): + if cards > 1: + create_new_process_flag = True + break + + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: + if create_new_process_flag: + # only for distributed tests + func = create_new_process_for_each_test()(f) + else: + func = f + for mark in reversed(all_marks): + func = mark(func) + return func + + return wrapper + + +class GPUMemoryMonitor: + """Poll global device memory usage via CUDA APIs.""" + + def __init__(self, device_index: int, interval: float = 0.05): + self.device_index = device_index + self.interval = interval + self._peak_used_mb = 0.0 + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + def monitor_loop() -> None: + while not self._stop_event.is_set(): + try: + with torch.cuda.device(self.device_index): + free_bytes, total_bytes = torch.cuda.mem_get_info() + used_mb = (total_bytes - free_bytes) / (1024**2) + self._peak_used_mb = max(self._peak_used_mb, used_mb) + except Exception: + pass + time.sleep(self.interval) + + self._thread = threading.Thread(target=monitor_loop, daemon=False) + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + self._stop_event.set() + self._thread.join(timeout=2.0) + + @property + def peak_used_mb(self) -> float: + fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2) + fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2) + return max(self._peak_used_mb, fallback_alloc, fallback_reserved) + + def __del__(self): + self.stop() diff --git a/tests/worker/test_gpu_generation_model_runner.py b/tests/worker/test_gpu_generation_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed1ae861dfea523e35ebff991c2619c9ab770c --- /dev/null +++ b/tests/worker/test_gpu_generation_model_runner.py @@ -0,0 +1,75 @@ +import torch + +from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner + + +class _DummyInputBatch: + def __init__(self): + self.req_ids = ["req-1"] + self.req_id_to_index = {"req-1": 0} + self.num_reqs = 1 + self.vocab_size = 10 + + +def _make_runner(multimodal_outputs): + runner = object.__new__(GPUGenerationModelRunner) + runner.execute_model_state = ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + 
multimodal_outputs, + ) + runner.kv_connector_output = None + runner.input_batch = _DummyInputBatch() + runner.use_async_scheduling = False + runner.device = torch.device("cpu") + runner.supports_mm_inputs = False + return runner + + +def test_sample_tokens_tensor_output(): + multimodal_outputs = torch.randn(1, 2, 3) + runner = _make_runner(multimodal_outputs) + + output = GPUGenerationModelRunner.sample_tokens(runner) + + assert len(output.pooler_output) == 1 + assert output.pooler_output[0]["model_outputs"].shape == (2, 3) + + +def test_sample_tokens_list_output(): + multimodal_outputs = [torch.randn(2, 1)] + runner = _make_runner(multimodal_outputs) + + output = GPUGenerationModelRunner.sample_tokens(runner) + + assert len(output.pooler_output) == 1 + assert output.pooler_output[0]["model_outputs"].shape == (2, 1) + + +def test_sample_tokens_list_allows_none_output(): + multimodal_outputs = [None] + runner = _make_runner(multimodal_outputs) + + output = GPUGenerationModelRunner.sample_tokens(runner) + + assert len(output.pooler_output) == 1 + assert output.pooler_output[0]["model_outputs"] is None + + +def test_sample_tokens_dict_output(): + multimodal_outputs = {"audio": torch.randn(1, 4), "unused": None} + runner = _make_runner(multimodal_outputs) + + output = GPUGenerationModelRunner.sample_tokens(runner) + + assert len(output.pooler_output) == 1 + assert "audio" in output.pooler_output[0] + assert "unused" not in output.pooler_output[0] + assert output.pooler_output[0]["audio"].shape == (1, 4) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1adf227d1172e7f616f11881b9ef5b48dff494 --- /dev/null +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -0,0 +1,130 @@ +from contextlib import contextmanager +from types import SimpleNamespace + +import torch + +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + +class DummyBuffer: + """A minimal buffer wrapper that exposes the `.gpu` attribute.""" + + def __init__(self, t: torch.Tensor): + self.gpu = t + + +class DummyInputBatch: + """A minimal input batch that only provides `req_ids`.""" + + def __init__(self, req_ids): + self.req_ids = req_ids + + +class DummyReqState: + """A minimal request state container.""" + + pass + + +class DummyTalkerMTP(torch.nn.Module): + """A fake talker_mtp module for deterministic CPU testing.""" + + def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): + # Deterministic behavior: + # - output embeds = input embeds + 1 + # - output codes = [[0], [1], ...] + bsz = req_embeds.shape[0] + new_embeds = req_embeds + 1.0 + codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1) + return new_embeds, codes + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + """A no-op context manager to replace vLLM forward context in CPU tests.""" + yield + + +def _make_runner(req_ids=("r1", "r2"), hidden_size=4): + # Create an instance without calling OmniGPUModelRunner.__init__ + runner = object.__new__(OmniGPUModelRunner) + + # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward + runner.input_batch = DummyInputBatch(list(req_ids)) + runner.requests = {rid: DummyReqState() for rid in req_ids} + + # query_start_loc.cpu[req_index] is used to locate the token position + # in the flattened `inputs_embeds`. 
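+    # A bare object exposing a `.cpu` tensor is enough to stand in for the
+    # runner's buffer in this CPU-only test.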
+ runner.query_start_loc = type("QSL", (), {})() + # Map: r1 -> offset 0, r2 -> offset 3 + runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32) + + bsz = len(req_ids) + runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64)) + runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + + runner.talker_mtp = DummyTalkerMTP() + runner.vllm_config = object() + + # Provide a minimal implementation that returns the expected 4-tuple. + def _determine_batch_execution_and_padding(**kwargs): + return None, object(), None, None, None + + runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding + + # Use the real merge method from OmniGPUModelRunner. + return runner + + +def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): + # Patch the module-level `set_forward_context` symbol used inside + # OmniGPUModelRunner._talker_mtp_forward. + import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4) + + def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn): + batch_desc = SimpleNamespace(num_tokens=int(num_tokens)) + return (False, batch_desc, None, None, None) + + monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner))) + + # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) + runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) + + # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten + inputs_embeds = torch.zeros((6, 4), dtype=torch.float32) + + # Call the original implementation from OmniGPUModelRunner (no re-implementation) + OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds) + + # Validate embeds were written back (+1) + assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0])) + assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0])) + + # Validate per-request additional_information_cpu was updated + info_r1 = runner.requests["r1"].additional_information_cpu + info_r2 = runner.requests["r2"].additional_information_cpu + assert int(info_r1["code_predictor_codes"][0, 0]) == 0 + assert int(info_r2["code_predictor_codes"][0, 0]) == 1 + + +def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1",), hidden_size=4) + + inputs_embeds = torch.randn((2, 4)) + before = inputs_embeds.clone() + + OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds) + + # Ensure no changes were made + assert torch.allclose(inputs_embeds, before) diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c0f3dd3b5457354313fc230efef4d955732074 --- /dev/null +++ b/tools/pre_commit/check_pickle_imports.py @@ -0,0 +1,85 @@ 
+#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys + +import regex as re + +# List of files (relative to repo root) that are allowed to import pickle or +# cloudpickle +# +# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST: +# The pickle and cloudpickle modules are known to be unsafe when deserializing +# data from potentially untrusted parties. They have resulted in multiple CVEs +# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly. +# Before adding new uses of pickle/cloudpickle, please consider safer +# alternatives like msgpack or pydantic that are already in use in vLLM. Only +# add to this list if absolutely necessary and after careful security review. +ALLOWED_FILES = { + "vllm_omni/entrypoints/omni_llm.py", + "tests/e2e/offline_inference/utils.py", + "tests/utils.py", + "vllm_omni/diffusion/distributed/group_coordinator.py", + "tests/diffusion/attention/test_attention_sp.py", +} + +PICKLE_RE = re.compile( + r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" + r"|from\s+(pickle|cloudpickle)\s+import\b)" +) + + +def scan_file(path: str) -> int: + with open(path, encoding="utf-8") as f: + for i, line in enumerate(f, 1): + if PICKLE_RE.match(line): + print( + f"{path}:{i}: " + "\033[91merror:\033[0m " # red color + "Found pickle/cloudpickle import" + ) + return 1 + return 0 + + +def main(): + returncode = 0 + for filename in sys.argv[1:]: + if filename in ALLOWED_FILES: + continue + returncode |= scan_file(filename) + return returncode + + +def test_regex(): + test_cases = [ + # Should match + ("import pickle", True), + ("import cloudpickle", True), + ("import pickle as pkl", True), + ("import cloudpickle as cpkl", True), + ("from pickle import *", True), + ("from cloudpickle import dumps", True), + ("from pickle import dumps, loads", True), + ("from cloudpickle import (dumps, loads)", True), + (" import pickle", True), + ("\timport cloudpickle", True), + ("from pickle import loads", True), + # Should not match + ("import somethingelse", False), + ("from somethingelse import pickle", False), + ("# import pickle", False), + ("print('import pickle')", False), + ("import pickleas as asdf", False), + ] + for i, (line, should_match) in enumerate(test_cases): + result = bool(PICKLE_RE.match(line)) + assert result == should_match, f"Test case {i} failed: '{line}' (expected {should_match}, got {result})" + print("All regex tests passed.") + + +if __name__ == "__main__": + if "--test-regex" in sys.argv: + test_regex() + else: + sys.exit(main()) diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25ea24b198a725c7ab25d8a06b81248b5215a057 --- /dev/null +++ b/vllm_omni/__init__.py @@ -0,0 +1,43 @@ +""" +vLLM-Omni: Multi-modality models inference and serving with +non-autoregressive structures. + +This package extends vLLM beyond traditional text-based, autoregressive +generation to support multi-modality models with non-autoregressive +structures and non-textual outputs. + +Architecture: +- 🟡 Modified: vLLM components modified for multimodal support +- 🔴 Added: New components for multimodal and non-autoregressive + processing +""" + +try: + from . 
import patch # noqa: F401 +except ModuleNotFoundError as exc: # pragma: no cover - optional dependency + if exc.name != "vllm": + raise + # Allow importing vllm_omni without vllm (e.g., documentation builds) + patch = None # type: ignore + + +from .config import OmniModelConfig +from .entrypoints.async_omni import AsyncOmni + +# Main entry points +from .entrypoints.omni import Omni + +from .version import __version__, __version_tuple__ # isort:skip + + +__all__ = [ + "__version__", + "__version_tuple__", + # Main components + "Omni", + "AsyncOmni", + # Configuration + "OmniModelConfig", + # All other components are available through their respective modules + # processors.*, schedulers.*, executors.*, etc. +] diff --git a/vllm_omni/__pycache__/outputs.cpython-314.pyc.2372032969296 b/vllm_omni/__pycache__/outputs.cpython-314.pyc.2372032969296 new file mode 100644 index 0000000000000000000000000000000000000000..80196c1f2934cef04945b79bdce02599b58c7327 Binary files /dev/null and b/vllm_omni/__pycache__/outputs.cpython-314.pyc.2372032969296 differ diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py new file mode 100644 index 0000000000000000000000000000000000000000..98b1f7e4e29259e575762332511371e4b9d5b6b8 --- /dev/null +++ b/vllm_omni/assets/video.py @@ -0,0 +1,16 @@ +import librosa +import numpy as np +from vllm.assets.video import VideoAsset + + +def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray: + """This function extracts the audio from a video file path and returns the audio as a numpy array. + Args: + path: The path to the video file. + Returns: + The audio as a numpy array. + """ + if not path: + path = VideoAsset(name="baby_reading").video_path + audio_signal, sr = librosa.load(path, sr=sampling_rate) + return audio_signal diff --git a/vllm_omni/benchmarks/data_modules/__init__.py b/vllm_omni/benchmarks/data_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/benchmarks/data_modules/random_multi_modal_dataset.py b/vllm_omni/benchmarks/data_modules/random_multi_modal_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..14ba86cc4cbe8f60d7bdbf641e125282ab15d457 --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/random_multi_modal_dataset.py @@ -0,0 +1,152 @@ +import base64 +import io +import logging +from collections.abc import Mapping +from typing import Any + +import numpy as np +import soundfile as sf +import torch +from vllm.benchmarks.datasets import RandomMultiModalDataset, process_image, process_video + +logger = logging.getLogger(__name__) + + +def process_audio(audio: Any) -> Mapping[str, Any]: + """ + Process a single audio input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw audio bytes: - Expects a dict with a 'bytes' key + containing raw audio data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the audio URL. + + Raises: + ValueError: If the input is not a supported type. 
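+
+    Example (illustrative, local file path input):
+        >>> process_audio("sample.wav")
+        {'type': 'audio_url', 'audio_url': {'url': 'file://sample.wav'}}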
+ """ + if isinstance(audio, dict) and "bytes" in audio: + audio_bytes = audio["bytes"] + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + return { + "type": "audio_url", + "audio_url": {"url": f"data:audio/mpeg;base64,{audio_base64}"}, + } + if isinstance(audio, str): + audio_url = audio if audio.startswith(("http://", "https://", "file://")) else f"file://{audio}" + return {"type": "audio_url", "audio_url": {"url": audio_url}} + + raise ValueError( + f"Invalid audio input {audio}. Must be a string of local path/remote url, " + f"or a dictionary with raw audio bytes in the form of `{{'bytes': raw_audio_bytes}}`." + ) + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- +class OmniRandomMultiModalDataset(RandomMultiModalDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def generate_synthetic_audio( + self, + duration: int, # seconds + num_channels: int, # 1:Mono,2:Stereo 5:5.1 surround sound + ) -> dict[str, Any]: + """Generate synthetic audio with random values. + Default use 48000Hz. + """ + sample_rate = 48000 + num_samples = int(sample_rate * duration) + audio_data = self._rng.uniform(-0.5, 0.5, (num_samples, num_channels)) + audio_data = np.clip(audio_data, -1.0, 1.0) + audio_tensor = torch.FloatTensor(audio_data.T) + audio_np = audio_tensor.numpy() + + buffer = io.BytesIO() + + sf.write(buffer, audio_np.T, sample_rate, format="wav") + + buffer.seek(0) + audio_bytes = buffer.read() + buffer.close() + return { + "bytes": audio_bytes, + } + + def generate_mm_item( + self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. 
+ This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image(mm_item_config[1], mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video(mm_item_config[1], mm_item_config[0], mm_item_config[2])) + elif self.map_config_to_modality(mm_item_config) == "audio": + return process_audio(self.generate_synthetic_audio(mm_item_config[1], mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: {mm_item_config}") + + def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any: + """Generate synthetic video with random values.""" + import imageio + + video_data = self._rng.integers( + 0, + 256, + (num_frames, height, width, 3), + dtype=np.uint8, + ) + buffer = io.BytesIO() + writer_kwargs = { + "format": "mp4", + "fps": 30, + "codec": "libx264", + "quality": 7, + "pixelformat": "yuv420p", + "macro_block_size": 16, + "ffmpeg_params": [ + "-preset", + "medium", + "-crf", + "23", + "-movflags", + "+faststart", + "-pix_fmt", + "yuv420p", + "-vf", + f"scale={width}:{height}", + ], + } + + with imageio.get_writer(buffer, **writer_kwargs) as writer: + for frame_idx in range(num_frames): + writer.append_data(video_data[frame_idx]) + buffer.seek(0) + video_bytes = buffer.read() + + return { + "bytes": video_bytes, + } + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[0] == 0: + return "audio" + elif config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") diff --git a/vllm_omni/benchmarks/metrics/__init__.py b/vllm_omni/benchmarks/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/benchmarks/metrics/metrics.py b/vllm_omni/benchmarks/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f404a12f8e66bfdc40ee5e15af2cef51ce8e64f2 --- /dev/null +++ b/vllm_omni/benchmarks/metrics/metrics.py @@ -0,0 +1,330 @@ +import warnings +from dataclasses import dataclass + +import numpy as np +from transformers import PreTrainedTokenizerBase +from vllm.benchmarks.datasets import SampleRequest +from vllm.benchmarks.lib.endpoint_request_func import RequestFuncOutput +from vllm.benchmarks.serve import MILLISECONDS_TO_SECONDS_CONVERSION, TERM_PLOTLIB_AVAILABLE, BenchmarkMetrics, TaskType + + +@dataclass +class MultiModalsBenchmarkMetrics(BenchmarkMetrics): + mean_audio_ttfp_ms: float = 0.0 + median_audio_ttfp_ms: float = 0.0 + std_audio_ttfp_ms: float = 0.0 + percentiles_audio_ttfp_ms: list[tuple[float, float]] = None + total_audio_duration_ms: float = 0.0 + total_audio_frames: int = 0 + audio_throughput: float = 0.0 + mean_audio_rtf: float = 0.0 + median_audio_rtf: float = 0.0 + std_audio_rtf: float = 0.0 + percentiles_audio_rtf: list[tuple[float, float]] = None + + +def print_metrics( + task_type, + selected_percentile_metrics, + max_concurrency, + request_rate, + benchmark_duration, + goodput_config_dict, + metrics: MultiModalsBenchmarkMetrics, +): + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10}".format("Failed 
requests:", metrics.failed)) + if max_concurrency is not None: + print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) + if request_rate != float("inf"): + print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) + if isinstance(metrics, MultiModalsBenchmarkMetrics): + print("{:<40} {:<10.2f}".format("Peak concurrent requests:", metrics.max_concurrent_requests)) + if task_type != TaskType.GENERATION or "e2el" in selected_percentile_metrics: + process_one_metric("e2el", metrics) + print_text_metrics(task_type, selected_percentile_metrics, metrics) + if task_type == TaskType.GENERATION: + print_audio_metrics(selected_percentile_metrics, metrics) + print("=" * 50) + + +def print_text_metrics(task_type, selected_percentile_metrics, metrics: MultiModalsBenchmarkMetrics): + print("{s:{c}^{n}}".format(s=" Text Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + if isinstance(metrics, MultiModalsBenchmarkMetrics): + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s)) + print("{:<40} {:<10.2f}".format("Peak concurrent requests:", metrics.max_concurrent_requests)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) + + if task_type == TaskType.GENERATION: + for metric in selected_percentile_metrics: + if metric == "e2el": + continue + if not metric.startswith("audio"): + process_one_metric(metric, metrics) + + +def print_audio_metrics(selected_percentile_metrics, metrics: MultiModalsBenchmarkMetrics): + print("{s:{c}^{n}}".format(s=" Audio Result ", n=50, c="=")) + print("{:<40} {:<10.2f}".format("Total audio duration generated(s):", metrics.total_audio_duration_ms)) + print("{:<40} {:<10}".format("Total audio frames generated:", metrics.total_audio_frames)) + print("{:<40} {:<10.2f}".format("Audio throughput(audio duration/s):", metrics.audio_throughput)) + for metric in selected_percentile_metrics: + if metric.startswith("audio"): + process_one_metric(metric, metrics) + + +def process_one_metric( + metric_attribute_name: str, + metrics: MultiModalsBenchmarkMetrics, +): + metric_header_map = { + "ttft": "Time to First Token", + "tpot": "Time per Output Token (excl. 
1st token)", + "itl": "Inter-token Latency", + "e2el": "End-to-end Latency", + "audio_ttfp": "Time to First Packet", + "audio_rtf": "Real Time Factor", + } + + header = metric_header_map.get(metric_attribute_name, metric_attribute_name) + print("{s:{c}^{n}}".format(s=header, n=50, c="-")) + + is_audio_rtf = metric_attribute_name == "audio_rtf" + + suffix = "" if is_audio_rtf else "_ms" + unit_suffix = "" if is_audio_rtf else " (ms)" + + mean_attr_name = f"mean_{metric_attribute_name}{suffix}" + mean_value = getattr(metrics, mean_attr_name, 0.0) + print(f"{f'Mean {metric_attribute_name.upper()}{unit_suffix}:':<40} {mean_value:<10.2f}") + + median_attr_name = f"median_{metric_attribute_name}{suffix}" + median_value = getattr(metrics, median_attr_name, 0.0) + print(f"{f'Median {metric_attribute_name.upper()}{unit_suffix}:':<40} {median_value:<10.2f}") + + percentiles_attr_name = f"percentiles_{metric_attribute_name}{suffix}" + percentiles = getattr(metrics, percentiles_attr_name, []) + + for percentile, value in percentiles: + p_str = str(int(percentile)) if percentile.is_integer() else str(percentile) + label = f"P{p_str} {metric_attribute_name.upper()}{unit_suffix}:" + print(f"{label:<40} {value:<10.2f}") + + +def calculate_metrics( + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], + task_type, + selected_percentile_metrics, + max_concurrency, + request_rate, + benchmark_duration, +) -> tuple[BenchmarkMetrics, list[int]]: + """Calculate the metrics for the benchmark. + + Args: + input_requests: The input requests. + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + tokenizer: The tokenizer to use. + selected_percentiles: The percentiles to select. + goodput_config_dict: The goodput configuration. + + Returns: + A tuple of the benchmark metrics and the actual output lengths. 
+ """ + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + audio_ttfps: list[float] = [] + audio_rtfs: list[float] = [] + audio_duration: list[float] = [] + audio_frames: list[int] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if not output_len: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len(tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + audio_ttfps.append(getattr(outputs[i], "audio_ttfp", 0.0)) + audio_rtfs.append(getattr(outputs[i], "audio_rtf", 0.0)) + audio_duration.append(getattr(outputs[i], "audio_duration", 0.0)) + audio_frames.append(getattr(outputs[i], "audio_frames", 0.0)) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION) + if "audio_ttft" in goodput_config_dict: + valid_metrics.append(audio_ttfps) + slo_values.append(goodput_config_dict["audio_ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration on the benchmark arguments.", + stacklevel=2, + ) + + # Calculate max output tokens per second metric + max_output_tokens_per_s = 0.0 + max_concurrent_requests = 0 + + # Find the time range across all successful requests + successful_outputs = [output for output in outputs if output.success] + failed_outputs = [output for output in outputs if not output.success] + if successful_outputs: + min_start_time = min(output.start_time for output in successful_outputs) + max_end_time = max(output.start_time + output.latency for output in successful_outputs) + + # Create second buckets (ceiling to ensure we capture all time) + duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1 + tokens_per_second = np.zeros(duration_seconds) + concurrent_requests_per_second = np.zeros(duration_seconds) + + for i, output in enumerate(successful_outputs): + # Calculate token generation timestamp using + # start_time, ttft, and itl + token_times = [output.start_time + output.ttft] + current_time = token_times[0] + for itl_value in output.itl: + current_time += itl_value + token_times.append(current_time) + + # Add tokens to second buckets + for token_time in token_times: + second_bucket = int(token_time - min_start_time) + if 0 <= second_bucket < duration_seconds: + tokens_per_second[second_bucket] += 1 + + # Track concurrent requests for each second this request was active + request_start_second = int(output.start_time - min_start_time) + request_end_second = int((output.start_time + output.latency) - min_start_time) + + for second in range(request_start_second, request_end_second + 1): + concurrent_requests_per_second[second] += 1 + + # Find the maximum tokens per second and corresponding + # concurrent requests + if len(tokens_per_second) > 0: + max_output_tokens_per_s = float(np.max(tokens_per_second)) + max_concurrent_requests = int(np.max(concurrent_requests_per_second)) + + if TERM_PLOTLIB_AVAILABLE: + import termplotlib as tpl + + fig = tpl.figure() + fig.plot( + np.arange(len(tokens_per_second)), + tokens_per_second, + title="Output tokens per second", + ) + fig.plot( + np.arange(len(concurrent_requests_per_second)), + concurrent_requests_per_second, + title="Concurrent requests per second", + ) + fig.show() + else: + print("tip: install termplotlib and gnuplot to plot the metrics") + + metrics = MultiModalsBenchmarkMetrics( + completed=completed, + failed=len(failed_outputs), + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by the endpoint + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles], + mean_audio_ttfp_ms=np.mean(audio_ttfps or 0) * 1000, + std_audio_ttfp_ms=np.std(audio_ttfps or 0) * 1000, + median_audio_ttfp_ms=np.median(audio_ttfps or 0) * 1000, + percentiles_audio_ttfp_ms=[(p, np.percentile(audio_ttfps or 0, p) * 1000) for p in selected_percentiles], + total_audio_duration_ms=sum(audio_duration), + total_audio_frames=sum(audio_frames), + audio_throughput=sum(audio_duration) / dur_s, + mean_audio_rtf=np.mean(audio_rtfs or 0), + std_audio_rtf=np.std(audio_rtfs or 0), + 
median_audio_rtf=np.median(audio_rtfs or 0), + percentiles_audio_rtf=[(p, np.percentile(audio_rtfs or 0, p)) for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles], + max_output_tokens_per_s=max_output_tokens_per_s, + max_concurrent_requests=max_concurrent_requests, + ) + print_metrics( + task_type, + selected_percentile_metrics, + max_concurrency, + request_rate, + benchmark_duration, + goodput_config_dict, + metrics, + ) + return metrics, actual_output_lens diff --git a/vllm_omni/benchmarks/patch/__init__.py b/vllm_omni/benchmarks/patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2e5c1955af0e8818e3b3296998cce570085e87 --- /dev/null +++ b/vllm_omni/benchmarks/patch/patch.py @@ -0,0 +1,538 @@ +import asyncio +import base64 +import contextlib +import io +import json +import os +import random +import sys +import time +import traceback +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Literal + +import aiohttp +from pydub import AudioSegment +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase +from vllm.benchmarks import datasets +from vllm.benchmarks.datasets import SampleRequest +from vllm.benchmarks.lib.endpoint_request_func import ( + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, + StreamedResponseHandler, + _get_chat_content, + _update_headers_common, + _update_payload_common, + _validate_api_url, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset + +get_samples_old = datasets.get_samples + + +def get_samples(args, tokenizer): + if args.backend not in ["openai-chat-omni"]: + raise ValueError("benchmark is only supported on 'openai-chat-omni' backend.") + if args.dataset_name == "random-mm": + dataset = OmniRandomMultiModalDataset(random_seed=args.seed, dataset_path=args.dataset_path) + input_requests = dataset.sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + return input_requests + else: + return get_samples_old(args, tokenizer) + + +datasets.get_samples = get_samples + + 
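+# With the patch above, runs with args.dataset_name == "random-mm" build their
+# requests from OmniRandomMultiModalDataset; any other dataset name falls back
+# to the original vLLM get_samples. Only the "openai-chat-omni" backend is
+# accepted.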
+@dataclass +class MixRequestFuncOutput(RequestFuncOutput): + audio_ttfp: float = 0.0 + audio_duration: float = 0.0 + audio_frames: int = 0 + audio_rtf: float = 0.0 + + +async def async_request_openai_chat_omni_completions( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, + mm_position: Literal["first", "last"] = "last", +) -> MixRequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions") + + content = _get_chat_content(request_func_input, mm_position=mm_position) + + payload = { + "model": request_func_input.model_name if request_func_input.model_name else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + "temperature": 0.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + output = MixRequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + generated_audio = None + ttft = 0.0 + st = time.perf_counter() + output.start_time = st + most_recent_timestamp = st + audio_generate_time = 0.0 + audio_first_timestamp = st + try: + async with session.post(url=api_url, json=payload, headers=headers) as response: + if response.status == 200: + handler = StreamedResponseHandler() + async for chunk_bytes in response.content.iter_any(): + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. 
+ if message.startswith(":"): + continue + + chunk = message.removeprefix("data: ") + + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + if choices := data.get("choices"): + modality = data.get("modality") + content = choices[0]["delta"].get("content") + if modality == "text": + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + else: + output.itl.append(timestamp - most_recent_timestamp) + generated_text += content or "" + elif modality == "audio": + if output.audio_ttfp == 0.0: + audio_first_timestamp = timestamp + output.audio_ttfp = timestamp - st + audio_generate_time = timestamp - audio_first_timestamp + if content != "": + audio_bytes = base64.b64decode(content) + seg = AudioSegment.from_file(io.BytesIO(audio_bytes)) + if seg is not None: + if generated_audio is None: + generated_audio = seg + else: + generated_audio = generated_audio + seg + + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + most_recent_timestamp = timestamp + + output.generated_text = generated_text + if generated_audio is not None: + output.audio_duration = len(generated_audio) / 1000.0 + frame_width = generated_audio.frame_width + if frame_width > 0: + output.audio_frames = len(generated_audio.raw_data) // frame_width + else: + output.audio_frames = 0 + logger.warning("Audio frame width is zero") + audio_duration = output.audio_duration + if audio_duration > 0: + output.audio_rtf = audio_generate_time / output.audio_duration + else: + output.audio_rtf = 0 + logger.warning("Audio duration is zero") + + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + logger.error(f"ERROR: send request failed, reason is: {output.error}") + + if pbar: + pbar.update(1) + return output + + +ASYNC_REQUEST_FUNCS["openai-chat-omni"] = async_request_openai_chat_omni_completions +if "openai-chat-omni" not in OPENAI_COMPATIBLE_BACKENDS: + OPENAI_COMPATIBLE_BACKENDS.append("openai-chat-omni") + +# ruff: noqa: E402 +# Prevent import order from causing patch failures +from vllm.benchmarks import serve +from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint + +from vllm_omni.benchmarks.metrics.metrics import MultiModalsBenchmarkMetrics, calculate_metrics + +# ruff: noqa: E402 + +benchmark_old = serve.benchmark + + +async def benchmark( + task_type: TaskType, + endpoint_type: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + logprobs: int | None, + request_rate: float, + burstiness: float, + disable_tqdm: bool, + num_warmups: int, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: int | None, + lora_modules: Iterable[str] | None, + extra_headers: dict | None, + extra_body: dict | None, + ramp_up_strategy: Literal["linear", "exponential"] | None = None, + ramp_up_start_rps: int | None = None, + ramp_up_end_rps: int | None = None, + ready_check_timeout_sec: int = 600, +): + try: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + except KeyError: + raise ValueError(f"Unknown backend: {endpoint_type}") from None + + # Reuses connections 
across requests to reduce TLS handshake overhead. + connector = aiohttp.TCPConnector( + limit=max_concurrency or 0, + limit_per_host=max_concurrency or 0, + ttl_dns_cache=300, + use_dns_cache=True, + keepalive_timeout=60, + enable_cleanup_closed=True, + force_close=False, + ssl=("https://" in api_url), + ) + + session = aiohttp.ClientSession( + connector=connector, + trust_env=True, + timeout=aiohttp.ClientTimeout(total=6 * 60 * 60), + ) + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) + + assert ( + test_mm_content is None + or isinstance(test_mm_content, dict) + or (isinstance(test_mm_content, list) and all(isinstance(item, dict) for item in test_mm_content)) + ), "multi_modal_data must be a dict or list[dict]" + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + if ready_check_timeout_sec > 0: + test_output = await wait_for_endpoint( + request_func, + test_input, + session, + timeout_seconds=ready_check_timeout_sec, + ) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark " + "arguments are correctly specified. " + f"Error: {test_output.error}" + ) + else: + print("Initial test run completed.") + else: + print("Skipping endpoint ready check.") + + if num_warmups > 0: + print(f"Warming up with {num_warmups} requests...") + warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups) + warmup_semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext() + warmup_tasks = [] + + async def warmup_limited_request_func(): + async with warmup_semaphore: + return await request_func(request_func_input=test_input, session=session, pbar=warmup_pbar) + + for _ in range(num_warmups): + request_task = asyncio.create_task(warmup_limited_request_func()) + warmup_tasks.append(request_task) + _ = await asyncio.gather(*warmup_tasks) + + if warmup_pbar is not None: + warmup_pbar.close() + print("Warmup run completed.") + + print("Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + ) + profile_output = await request_func(request_func_input=profile_input, session=session) + if profile_output.success: + print("Profiler started") + + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" + + if ramp_up_strategy is not None: + print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") + print( + f"Will increase RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over the duration of the benchmark." 
+ ) + else: + print(f"Traffic request rate: {request_rate}") + + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext() + + async def limited_request_func(request_func_input, session, pbar): + async with semaphore: + return await request_func(request_func_input=request_func_input, session=session, pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append( + { + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + } + ) + + async for request, current_request_rate in get_request( + input_requests, + request_rate, + burstiness, + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + ): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) + last_int_rps = current_int_rps + prompt, prompt_len, output_len, mm_content, request_id = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + request.request_id, + ) + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + request_id=request_id, + ) + tasks.append( + asyncio.create_task(limited_request_func(request_func_input=request_func_input, session=session, pbar=pbar)) + ) + outputs: list[MixRequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + task_type=task_type, + selected_percentile_metrics=selected_percentile_metrics, + max_concurrency=max_concurrency, + request_rate=request_rate, + benchmark_duration=benchmark_duration, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 + + if isinstance(metrics, MultiModalsBenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "failed": metrics.failed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + 
"input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "max_output_tokens_per_s": metrics.max_output_tokens_per_s, + "max_concurrent_requests": metrics.max_concurrent_requests, + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } + + if rps_change_events: + result["rps_change_events"] = rps_change_events + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + is_audio_rtf = metric_attribute_name == "audio_rtf" + + suffix = "" if is_audio_rtf else "_ms" + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}{suffix}"): + p_word = str(int(p)) if int(p) == p else str(p) + result[f"p{p_word}_{metric_attribute_name}{suffix}"] = value + + if task_type == TaskType.GENERATION: + for metric in selected_percentile_metrics: + process_one_metric(metric) + else: + process_one_metric("e2el") + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input, session=session) + if profile_output.success: + print("Profiler stopped") + + await session.close() + return result + + +serve.benchmark = benchmark diff --git a/vllm_omni/benchmarks/serve.py b/vllm_omni/benchmarks/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..fe946036931625f09d59a3faccceb2b6d4f9ea4c --- /dev/null +++ b/vllm_omni/benchmarks/serve.py @@ -0,0 +1,9 @@ +import argparse +import asyncio +from typing import Any + +from vllm.benchmarks.serve import main_async + + +def main(args: argparse.Namespace) -> dict[str, Any]: + return asyncio.run(main_async(args)) diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2db6f4273c8556425ef32b6b8f4719e80e72092 --- /dev/null +++ b/vllm_omni/config/__init__.py @@ -0,0 +1,11 @@ +""" +Configuration module for vLLM-Omni. 
+""" + +from vllm_omni.config.lora import LoRAConfig +from vllm_omni.config.model import OmniModelConfig + +__all__ = [ + "OmniModelConfig", + "LoRAConfig", +] diff --git a/vllm_omni/config/lora.py b/vllm_omni/config/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..00aba2e16b6c4f39625c22cf8d83f61e1727cd4e --- /dev/null +++ b/vllm_omni/config/lora.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# for now, it suffices to use vLLM's implementation directly +# as this is a user-facing variable, defined here to so that user can directly import LoRAConfig from vllm_omni +from vllm.config.lora import LoRAConfig + +__all__ = ["LoRAConfig"] diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py new file mode 100644 index 0000000000000000000000000000000000000000..dad07dc226ac60062af77eff96131adc49fe44b7 --- /dev/null +++ b/vllm_omni/config/model.py @@ -0,0 +1,305 @@ +import warnings +from dataclasses import field +from typing import Any + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass +from vllm.config import ModelConfig, config +from vllm.config.model import ( + _RUNNER_CONVERTS, + _get_and_verify_dtype, + get_served_model_name, +) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig +from vllm.config.pooler import PoolerConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.transformers_utils.config import ( + get_config, + get_hf_image_processor_config, + get_hf_text_config, + get_pooling_config, +) +from vllm.transformers_utils.gguf_utils import is_gguf, maybe_patch_hf_config_from_gguf +from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.v1.attention.backends.registry import AttentionBackendEnum + +import vllm_omni.model_executor.models as me_models + +logger = init_logger(__name__) + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class OmniModelConfig(ModelConfig): + """Configuration for Omni models, extending the base ModelConfig. + + This configuration class extends the base vLLM ModelConfig with + omni-specific fields for multi-stage pipeline processing. + + Attributes: + stage_id: Identifier for the stage in a multi-stage pipeline (default: 0) + async_chunk: If set to True, perform async chunk + model_stage: Stage type identifier, e.g., "thinker" or "talker" + (default: "thinker") + model_arch: Model architecture name + (default: "Qwen2_5OmniForConditionalGeneration") + engine_output_type: Optional output type specification for the engine. + Used to route outputs to appropriate processors (e.g., "image", + "audio", "latents"). If None, output type is inferred. + stage_connector_config: Stage connector configuration dictionary. + Contains "name" (connector name), "extra" (extra connector config). + + Example: + >>> config = OmniModelConfig( + ... stage_id=0, + ... model_stage="thinker", + ... model_arch="Qwen2_5OmniForConditionalGeneration" + ... 
) + """ + + stage_id: int = 0 + async_chunk: bool = False + model_stage: str = "thinker" + model_arch: str = "Qwen2_5OmniForConditionalGeneration" + engine_output_type: str | None = None + hf_config_name: str | None = None + custom_process_next_stage_input_func: str | None = None + stage_connector_config: dict[str, Any] = field( + default_factory=lambda: { + "name": "SharedMemoryConnector", + "extra": {}, + } + ) + omni_kv_config: dict | None = None + + @property + def registry(self): + return me_models.OmniModelRegistry + + @property + def architectures(self) -> list[str]: + return [self.model_arch] + + def draw_hf_text_config(self): + # transformers' get_text_config method is used to get the text config from thinker_config. + # to handle the case that each model stage has their own text config, + # we need to draw the text config from the corresponding model stage. + if self.hf_config_name is None: + return get_hf_text_config(self.hf_config) + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(self.hf_config, self.hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{self.hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(self.hf_config) + + def __post_init__( + self, + # Multimodal config init vars + limit_mm_per_prompt: dict[str, int | dict[str, int]] | None, + enable_mm_embeds: bool | None, + media_io_kwargs: dict[str, dict[str, Any]] | None, + mm_processor_kwargs: dict[str, Any] | None, + mm_processor_cache_gb: float | None, + mm_processor_cache_type: MMCacheType | None, + mm_shm_cache_max_object_size_mb: int | None, + mm_encoder_only: bool | None, + mm_encoder_tp_mode: MMEncoderTPMode | None, + mm_encoder_attn_backend: AttentionBackendEnum | str | None, + interleave_mm_strings: bool | None, + skip_mm_profiling: bool | None, + video_pruning_rate: float | None, + ) -> None: + # Keep set served_model_name before maybe_model_redirect(self.model) + self.served_model_name = get_served_model_name(self.model, self.served_model_name) + self.model = maybe_model_redirect(self.model) + # The tokenizer is consistent with the model by default. 
+ if self.tokenizer is None: + self.tokenizer = self.model + if self.tokenizer_revision is None: + self.tokenizer_revision = self.revision + self.tokenizer = maybe_model_redirect(self.tokenizer) + + if isinstance(self.hf_config_path, str): + self.hf_config_path = maybe_model_redirect(self.hf_config_path) + + if callable(self.hf_overrides): + hf_overrides_kw = {} + hf_overrides_fn = self.hf_overrides + dict_overrides: dict[str, Any] = {} + else: + # Separate dict overrides from flat ones + # We'll determine how to apply dict overrides after loading the config + hf_overrides_kw = {} + dict_overrides = {} + for key, value in self.hf_overrides.items(): + if isinstance(value, dict): + dict_overrides[key] = value + else: + hf_overrides_kw[key] = value + hf_overrides_fn = None + + self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) + + if self.override_attention_dtype is not None and not current_platform.is_rocm(): + warnings.warn( + "override-attention-dtype is set but not using ROCm platform", + stacklevel=2, + ) + + if self.enable_sleep_mode and not current_platform.is_sleep_mode_available(): + raise ValueError("Sleep mode is not supported on current platform.") + + hf_config = get_config( + self.hf_config_path or self.model, + self.trust_remote_code, + self.revision, + self.code_revision, + self.config_format, + hf_overrides_kw=hf_overrides_kw, + hf_overrides_fn=hf_overrides_fn, + ) + hf_config = maybe_patch_hf_config_from_gguf( + self.model, + hf_config, + ) + + self.hf_config = hf_config + if dict_overrides: + self._apply_dict_overrides(hf_config, dict_overrides) + self.hf_text_config = self.draw_hf_text_config() + self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision + ) + self.model_arch_config = self.get_model_arch_config() + + if self.convert == "mm_encoder_only": + logger.warning_once( + "`--convert mm_encoder_only` is deprecated and " + "will be removed in v0.15. " + "Please use --mm-encoder-only` instead." + ) + mm_encoder_only = True + self.convert = "none" + + architectures = self.architectures + registry = self.registry + is_generative_model = registry.is_text_generation_model(architectures, self) + is_pooling_model = registry.is_pooling_model(architectures, self) + + self.runner_type = self._get_runner_type(architectures, self.runner) + self.convert_type = self._get_convert_type(architectures, self.runner_type, self.convert) + + if self.runner_type == "generate" and not is_generative_model: + generate_converts = _RUNNER_CONVERTS["generate"] + if self.convert_type not in generate_converts: + # Currently we don't have any converters for generative models + raise ValueError("This model does not support `--runner generate`.") + if self.runner_type == "pooling" and not is_pooling_model: + pooling_converts = _RUNNER_CONVERTS["pooling"] + if self.convert_type not in pooling_converts: + convert_option = "<" + "|".join(pooling_converts) + ">" + raise ValueError( + "This model does not support `--runner pooling`. " + f"You can pass `--convert {convert_option} to adapt " + "it into a pooling model." 
+ ) + + # Note: Initialize these attributes early because transformers fallback + # may fail to load dynamic modules in child processes + model_info, arch = registry.inspect_model_cls(architectures, self) + self._model_info = model_info + self._architecture = arch + logger.info("Resolved architecture: %s", arch) + + # Init pooler config if needed + if self.runner_type == "pooling": + if self.pooler_config is None: + self.pooler_config = PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(self.pooler_config, k) is None: + setattr(self.pooler_config, k, v) + + default_seq_pooling_type = self._model_info.default_seq_pooling_type + if self.pooler_config.seq_pooling_type is None: + self.pooler_config.seq_pooling_type = default_seq_pooling_type + default_tok_pooling_type = self._model_info.default_tok_pooling_type + if self.pooler_config.tok_pooling_type is None: + self.pooler_config.tok_pooling_type = default_tok_pooling_type + + self.dtype: torch.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) + + self.original_max_model_len = self.max_model_len + self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + + if self.is_encoder_decoder: + self.mm_processor_cache_gb = 0 + logger.info("Encoder-decoder model detected, disabling mm processor cache.") + + # Init multimodal config if needed + if self._model_info.supports_multimodal: + if mm_encoder_tp_mode == "data" and not self._model_info.supports_multimodal_encoder_tp_data: + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`." + ) + mm_encoder_tp_mode = "weights" + + mm_config_kwargs = dict( + limit_per_prompt=limit_mm_per_prompt, + enable_mm_embeds=enable_mm_embeds, + media_io_kwargs=media_io_kwargs, + mm_processor_kwargs=mm_processor_kwargs, + mm_processor_cache_gb=mm_processor_cache_gb, + mm_processor_cache_type=mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, + mm_encoder_only=mm_encoder_only, + mm_encoder_tp_mode=mm_encoder_tp_mode, + mm_encoder_attn_backend=mm_encoder_attn_backend, + interleave_mm_strings=interleave_mm_strings, + skip_mm_profiling=skip_mm_profiling, + video_pruning_rate=video_pruning_rate, + ) + + mm_config_kwargs = {k: v for k, v in mm_config_kwargs.items() if v is not None} + + self.multimodal_config = MultiModalConfig(**mm_config_kwargs) + + # Multimodal GGUF models must use original repo for mm processing + if is_gguf(self.tokenizer) and self.is_multimodal_model: + raise ValueError( + "Loading a multimodal GGUF model needs to use original " + "tokenizer. Please specify the unquantized hf model's " + "repo name or path using the --tokenizer argument." 
+            )
+
+        if self.disable_sliding_window:
+            # Set after get_and_verify_max_len to ensure that max_model_len
+            # can be correctly capped to sliding window size
+            self.hf_text_config.sliding_window = None
+
+        # Avoid running try_verify_and_update_config multiple times
+        self.config_updated = False
+        self._try_verify_and_update_model_config()
+        self._verify_quantization()
+        self._verify_cuda_graph()
+        self._verify_bnb_config()
diff --git a/vllm_omni/core/__init__.py b/vllm_omni/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_omni/core/sched/__init__.py b/vllm_omni/core/sched/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf18d07acf5d526696d49008bf139cc26019499
--- /dev/null
+++ b/vllm_omni/core/sched/__init__.py
@@ -0,0 +1,13 @@
+"""
+Scheduling components for vLLM-Omni.
+"""
+
+from .omni_ar_scheduler import OmniARScheduler
+from .omni_generation_scheduler import OmniGenerationScheduler
+from .output import OmniNewRequestData
+
+__all__ = [
+    "OmniARScheduler",
+    "OmniGenerationScheduler",
+    "OmniNewRequestData",
+]
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af96e718ce416ab3263c598eac8d38c34f56185
--- /dev/null
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -0,0 +1,644 @@
+from __future__ import annotations
+
+import importlib
+import time
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from typing import Any
+
+from vllm.compilation.cuda_graph import CUDAGraphStat
+from vllm.distributed.kv_events import KVEventBatch
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
+from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
+from vllm.v1.core.sched.utils import remove_all
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.metrics.perf import PerfStats
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
+from vllm_omni.core.sched.output import OmniSchedulerOutput
+from vllm_omni.distributed.omni_connectors.adapter import get_chunk, put_chunk
+from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
+from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class KVCacheTransferData:
+    request_id: str
+    layer_blocks: dict[str, Any]
+    block_ids: list[int]
+    metadata: dict[str, Any]
+
+    def to_dict(self) -> dict[str, Any]:
+        return asdict(self)
+
+
+class OmniARScheduler(VLLMScheduler):
+    """
+    OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.
+
+    This scheduler extends vLLM's scheduler to support multimodal and
+    non-autoregressive processing with additional fields and methods
+    specific to vLLM-Omni.
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Track requests that need KV cache transfer when finished + # Value is {"seq_len": int, "block_ids": list[int]} + self.requests_needing_kv_transfer: dict[str, dict[str, Any]] = {} + + # Track requests waiting for KV transfer (blocks not freed yet) + self.waiting_for_transfer_free: set[str] = set() + + # Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids) + self.active_kv_transfers: set[str] = set() + + # [Omni] Pre-parse KV transfer criteria + self.kv_transfer_criteria = self._get_kv_transfer_criteria() + + # Track requests that have already triggered prefill transfer to avoid duplicates + self.transfer_triggered_requests: set[str] = set() + model_config = self.vllm_config.model_config + self.omni_connector = None + if model_config.async_chunk: + connector_config = model_config.stage_connector_config + connector_specs = ConnectorSpec( + name=connector_config.get("name", "SharedMemoryConnector"), + extra=connector_config.get("extra", {}), + ) + self.omni_connector = OmniConnectorFactory.create_connector(connector_specs) + + custom_process_next_stage_input_func = getattr( + self.vllm_config.model_config, "custom_process_next_stage_input_func", None + ) + if custom_process_next_stage_input_func: + module_path, func_name = custom_process_next_stage_input_func.rsplit(".", 1) + module = importlib.import_module(module_path) + self.custom_process_next_stage_input_func = getattr(module, func_name) + + self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None) + + def _get_kv_transfer_criteria(self) -> dict | None: + # Note: vllm_config is available in Scheduler after super().__init__ + if not hasattr(self, "vllm_config"): + return None + + omni_kv_config = getattr(self.vllm_config.model_config, "omni_kv_config", None) + if omni_kv_config: + if isinstance(omni_kv_config, dict): + return omni_kv_config.get("kv_transfer_criteria", None) + else: + return getattr(omni_kv_config, "kv_transfer_criteria", None) + return None + + def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool: + """ + Check triggers and process side effects (marking transfer). + Returns True if request should be STOPPED. + Returns False if request should continue (even if transfer was triggered). 
+ """ + if not self.kv_transfer_criteria: + return False + + if request.request_id in self.waiting_for_transfer_free: + return False + + criteria_type = self.kv_transfer_criteria.get("type") + + # Universal duplicate check for once semantics + if request.request_id in self.transfer_triggered_requests: + return False + + if criteria_type == "prefill_finished": + if request.num_computed_tokens >= request.num_prompt_tokens: + logger.debug(f"[Omni] Request {request.request_id} triggered prefill_finished transfer (Non-Stop)") + self.transfer_triggered_requests.add(request.request_id) + self._mark_request_for_kv_transfer(request.request_id, request.num_computed_tokens) + + # Return False means "Do NOT stop the request" -> Continue Decoding + return False + + elif criteria_type == "special_token": + target_token_id = self.kv_transfer_criteria.get("token_id") + if target_token_id is not None and target_token_id in new_token_ids: + logger.debug(f"[Omni] Request {request.request_id} triggered special_token criteria (Non-Stop)") + + self.transfer_triggered_requests.add(request.request_id) + + # Calculate precise snapshot length (trim to sentinel) + # Find the FIRST occurrence of the sentinel + try: + idx = new_token_ids.index(target_token_id) + # seq_len = tokens_before_this_step + idx + 1 (include sentinel) + # request.num_computed_tokens already includes ALL new_token_ids + # so we subtract (len(new_token_ids) - (idx + 1)) + tokens_to_exclude = len(new_token_ids) - (idx + 1) + snapshot_len = request.num_computed_tokens - tokens_to_exclude + except ValueError: + snapshot_len = request.num_computed_tokens + + # Trigger Transfer + self._mark_request_for_kv_transfer(request.request_id, snapshot_len) + + # Do NOT stop request + return False + + return False + + def schedule(self) -> SchedulerOutput: # type: ignore[override] + scheduler_output = super().schedule() + try: + # Late import to avoid circulars in some launch modes + from .output import OmniNewRequestData + + # Rewrap base NewRequestData entries with OmniNewRequestData, + # enriching with request-level payloads + new_list = [] + for nr in scheduler_output.scheduled_new_reqs: + req_id = getattr(nr, "req_id", None) + request = self.requests.get(req_id) if req_id else None + # Build omni entry preserving all base fields + omni_nr = OmniNewRequestData( + req_id=nr.req_id, + external_req_id=(getattr(request, "external_req_id", None) if request else None), + prompt_token_ids=nr.prompt_token_ids, + mm_features=nr.mm_features, + sampling_params=nr.sampling_params, + pooling_params=nr.pooling_params, + block_ids=nr.block_ids, + num_computed_tokens=nr.num_computed_tokens, + lora_request=nr.lora_request, + # Enrich with omni payloads from the live request object + prompt_embeds=(getattr(request, "prompt_embeds", None) if request else None), + additional_information=(getattr(request, "additional_information", None) if request else None), + ) + new_list.append(omni_nr) + + scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment] + if self.omni_connector is not None: + get_chunk(self.omni_connector, scheduler_output) + + # Add information about requests needing KV cache transfer + finished_reqs = self.get_finished_requests_needing_kv_transfer() + except Exception: + # If anything goes wrong, leave the original output unchanged + init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData") + finished_reqs = {} + + # Wrap in omni scheduler output to carry transfer metadata. 
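+        # All base SchedulerOutput fields are copied verbatim; the extra
+        # finished_requests_needing_kv_transfer map travels with the output so
+        # the model runner knows which finished requests still need their KV
+        # cache blocks extracted before those blocks can be freed.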
+ base_fields = SchedulerOutput.__dataclass_fields__.keys() + base_data = {name: getattr(scheduler_output, name) for name in base_fields} + return OmniSchedulerOutput( + **base_data, + finished_requests_needing_kv_transfer=finished_reqs, + ) + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats + + perf_stats: PerfStats | None = None + if self.perf_metrics and self.perf_metrics.is_enabled(): + perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output) + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: KVConnectorStats | None = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) + if kv_connector_stats and self.connector: + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) + + failed_kv_load_req_ids = None + if kv_connector_output and kv_connector_output.invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids) + + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. + if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + # If async scheduling, num_output_placeholders also includes + # the scheduled spec tokens count and so is similarly adjusted. 
+ if request.num_output_placeholders > 0: + request.num_output_placeholders -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted, + num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens, + request_id=req_id, + ) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + pooler_output = pooler_outputs[req_index] if pooler_outputs else None + kv_transfer_params = None + status_before_stop = request.status + finish_reason = None + routed_experts = None + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output(request, new_token_ids) + elif request.pooling_params and pooler_output is not None: + # Pooling stops as soon as there is output. + request.status = RequestStatus.FINISHED_STOPPED + stopped = True + + # If criteria returns True, it means we must STOP the request. + # If criteria returns False, it might have triggered a background + # transfer (e.g. prefill finished / special token) but continues decoding. + if not stopped and self._process_kv_transfer_trigger(request, new_token_ids): + stopped = True + + if stopped: + routed_experts = self._get_routed_experts(request) + + # Capture finish_reason BEFORE _handle_stopped_request, which may + # reset the status to WAITING for streaming requests that continue. + finish_reason = request.get_finished_reason() + finished = self._handle_stopped_request(request) + if finished: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs: + new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) + + if new_token_ids and self.structured_output_manager.should_advance(request): + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + if not ok: + logger.warning( + "Unexpected: grammar rejected tokens %s for request %s.", + new_token_ids, + req_id, + ) + + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None or kv_transfer_params or stopped: + # Add EngineCoreOutput for this Request. 
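+                # When async chunking is enabled, the pooler output for this
+                # request is also forwarded to the downstream stage through the
+                # omni connector (see the put_chunk call below).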
+ outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=finish_reason, + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + routed_experts=routed_experts, + num_nans_in_logits=request.num_nans_in_logits, + ) + ) + if self.omni_connector is not None: + custom_process_next_stage_input_func = self.custom_process_next_stage_input_func + put_chunk(self.omni_connector, pooler_output, request, custom_process_next_stage_input_func) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # [Main] Handle failed KV load requests + if failed_kv_load_req_ids and not self.recompute_kv_load_failures: + requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] + self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) + for request in requests: + outputs[request.client_index].append( + EngineCoreOutput( + request_id=request.request_id, + new_token_ids=[], + finish_reason=request.get_finished_reason(), + events=request.take_events(), + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + ) + ) + + # [Omni] Cleanup state for finished requests + for req in stopped_running_reqs: + if req.request_id not in self.waiting_for_transfer_free: + if req.request_id in self.transfer_triggered_requests: + self.transfer_triggered_requests.remove(req.request_id) + if req.request_id in self.active_kv_transfers: + self.active_kv_transfers.remove(req.request_id) + + # Same for preempted + for req in stopped_preempted_reqs: + if req.request_id not in self.waiting_for_transfer_free: + if req.request_id in self.transfer_triggered_requests: + self.transfer_triggered_requests.remove(req.request_id) + if req.request_id in self.active_kv_transfers: + self.active_kv_transfers.remove(req.request_id) + # KV Connector: update state for finished KV Transfers. + if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()} + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. 
+ for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set) + finished_req_ids.clear() + + if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + # This is where we free blocks that were held for transfer + try: + kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None) + if kv_extracted_ids: + for req_id in kv_extracted_ids: + # Mark transfer as finished + if req_id in self.active_kv_transfers: + self.active_kv_transfers.remove(req_id) + logger.debug(f"[Omni] KV Transfer finished for {req_id}") + + if req_id in self.waiting_for_transfer_free: + # Now it's safe to free blocks + req = self.requests.get(req_id) + if req: + self.kv_cache_manager.free(req) + if req_id in self.requests: + del self.requests[req_id] + if req_id in self.transfer_triggered_requests: + self.transfer_triggered_requests.remove(req_id) + if req_id in self.active_kv_transfers: + self.active_kv_transfers.remove(req_id) + + logger.debug(f"Freed blocks for {req_id} after transfer extraction") + self.waiting_for_transfer_free.remove(req_id) + except Exception: + init_logger(__name__).exception("Failed to process finished transfer requests") + + return engine_core_outputs + + def _free_request(self, request: Request) -> dict[str, Any] | None: + # TODO(wzliu)! for offline mode, we should not end process until all data is transferred + """Mark a request as finished and free its resources.""" + + # 1. Standard cleanup parts from base _free_request + delay_free_blocks = False + kv_xfer_params = None + if self.connector is not None: + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + + self.encoder_cache_manager.free(request) + request_id = request.request_id + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + + # 2. Omni Specific: Check if we need to transfer KV + if self._should_transfer_kv_for_request(request_id): + already_triggered = request_id in self.transfer_triggered_requests + is_active = request_id in self.active_kv_transfers + + if already_triggered: + if is_active: + # It triggered but hasn't finished yet. We MUST wait. + logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. Waiting.") + self.waiting_for_transfer_free.add(request_id) + # We do NOT mark for transfer again, just wait. + kv_xfer_params = None # No new transfer params + return kv_xfer_params + else: + logger.debug( + f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). " + "Freeing immediately." 
+ ) + else: + self.waiting_for_transfer_free.add(request_id) + self._mark_request_for_kv_transfer(request_id, request.num_computed_tokens) + # Return KV transfer metadata so it propagates to RequestOutput + if request_id in self.requests_needing_kv_transfer: + transfer_data = self.requests_needing_kv_transfer[request_id] + kv_xfer_params = { + "past_key_values": transfer_data["block_ids"], + "kv_metadata": {"seq_len": transfer_data["seq_len"], "block_ids": transfer_data["block_ids"]}, + } + # Also update request.additional_information for good measure + add_info = getattr(request, "additional_information", None) + # If additional_information is an AdditionalInformationPayload-like object, + # unpack list_data into a plain dict. + if ( + add_info is not None + and hasattr(add_info, "entries") + and isinstance(getattr(add_info, "entries"), dict) + ): + request.additional_information = { + k: getattr(v, "list_data") + for k, v in getattr(add_info, "entries").items() + if getattr(v, "list_data", None) is not None + } + add_info = request.additional_information + if add_info is None: + request.additional_information = {} + add_info = request.additional_information + if isinstance(add_info, dict): + add_info.update(kv_xfer_params) + + return kv_xfer_params + + # 3. Standard Freeing + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + # Helper to match base class structure if not directly available + # VLLMScheduler has _free_blocks + super()._free_blocks(request) + + def _mark_request_for_kv_transfer(self, req_id: str, seq_len: int) -> None: + """Mark a request as needing KV cache transfer when it finishes.""" + # Avoid duplicate marking (if already pending in queue) + if req_id in self.requests_needing_kv_transfer: + return + + if self._should_transfer_kv_for_request(req_id): + # [Omni] Get block IDs from KVCacheManager + try: + block_ids_tuple = self.kv_cache_manager.get_block_ids(req_id) + if block_ids_tuple and len(block_ids_tuple) > 0: + block_ids = block_ids_tuple[0] + + # [Omni] Fix: Truncate blocks to match seq_len snapshot + # We need to know block_size. 
Usually in self.cache_config.block_size + # Note: vllm_config might not be directly available, check scheduler_config or cache_config + if hasattr(self, "cache_config") and hasattr(self.cache_config, "block_size"): + block_size = self.cache_config.block_size + elif hasattr(self, "scheduler_config") and hasattr( + self.scheduler_config, "block_size" + ): # Some versions + block_size = self.scheduler_config.block_size + else: + raise ValueError("Block size not found in cache_config or scheduler_config") + + # ceil(seq_len / block_size) + num_blocks = (seq_len + block_size - 1) // block_size + if len(block_ids) > num_blocks: + logger.debug( + f"[Omni] Truncating blocks for {req_id} from {len(block_ids)} " + f"to {num_blocks} (seq_len={seq_len})" + ) + block_ids = block_ids[:num_blocks] + + else: + block_ids = [] + except Exception as e: + init_logger(__name__).warning(f"Failed to get block IDs for {req_id}: {e}") + block_ids = [] + + self.requests_needing_kv_transfer[req_id] = {"seq_len": seq_len, "block_ids": block_ids} + logger.debug(f"Marked request {req_id} for KV cache transfer (len={seq_len}, blocks={len(block_ids)})") + + def _should_transfer_kv_for_request(self, req_id: str) -> bool: + """Determine if a request should trigger KV cache transfer.""" + need_send = False + # Try to read from vLLM Config (where YAML config is typically loaded) + # Check for omni_kv_config attribute + omni_kv_config = getattr(self.vllm_config.model_config, "omni_kv_config", None) + if omni_kv_config: + # omni_kv_config could be an object or a dict + if isinstance(omni_kv_config, dict): + need_send = omni_kv_config.get("need_send_cache", False) + else: + need_send = getattr(omni_kv_config, "need_send_cache", False) + return need_send + + def has_requests(self) -> bool: + """Check if there are any requests to process, including KV transfers.""" + # [Omni] Also check for pending KV transfers + if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free: + return True + return super().has_requests() + + def has_finished_requests(self) -> bool: + """Check if there are any finished requests (including those needing KV transfer).""" + if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free: + return True + return super().has_finished_requests() + + def has_unfinished_requests(self) -> bool: + """Check if there are any unfinished requests (including those needing KV transfer).""" + # [Omni] Also check for pending KV transfers to ensure the engine loop continues + # MUST verify waiting_for_transfer_free and active_kv_transfers + # Otherwise engine loop might exit before transfer Ack is received. + if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free: + return True + return super().has_unfinished_requests() + + def get_finished_requests_needing_kv_transfer(self) -> dict[str, dict]: + """Get and clear the list of requests needing KV cache transfer. 
+ Returns dict: {req_id: {"seq_len": int, "block_ids": list[int]}} + """ + requests = self.requests_needing_kv_transfer.copy() + + # Mark these requests as ACTIVE (sent to runner) + self.active_kv_transfers.update(requests.keys()) + + self.requests_needing_kv_transfer.clear() + return requests diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..50e7b8dbc3b051036f4a7b8d1f8eb76b0ddce0e7 --- /dev/null +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -0,0 +1,463 @@ +import time +from collections import defaultdict + +from vllm.compilation.cuda_graph import CUDAGraphStat +from vllm.distributed.kv_events import KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.request_queue import create_request_queue +from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler +from vllm.v1.core.sched.utils import remove_all +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs +from vllm.v1.metrics.perf import PerfStats +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats + +from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData +from vllm_omni.distributed.omni_connectors.adapter import get_chunk_for_generation +from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory +from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec +from vllm_omni.outputs import OmniModelRunnerOutput + + +class OmniGenerationScheduler(VLLMScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + model_config = self.vllm_config.model_config + self.omni_connector = None + if model_config.async_chunk: + connector_config = model_config.stage_connector_config + connector_specs = ConnectorSpec( + name=connector_config.get("name", "SharedMemoryConnector"), + extra=connector_config.get("extra", {}), + ) + self.omni_connector = OmniConnectorFactory.create_connector(connector_specs) + self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None) + + def schedule(self) -> SchedulerOutput: + """Diffusion fast path: + - Feed all input tokens of the request at once + (if 0, allocate 1 placeholder token). + - If the token budget cannot be satisfied at once, fall back to the + default vLLM scheduling. 
+ """ + + token_budget = self.max_num_scheduled_tokens + scheduled_timestamp = time.monotonic() + + scheduled_new_reqs: list[Request] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + scheduled_running_reqs: list[Request] = [] + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + scheduled_encoder_inputs: dict[str, list[int]] = {} + cached_prompt_token_ids: dict[str, list[int]] = {} + + # Temporary queue: preserve waiting order, do not disturb non-diffusion requests + skipped_waiting_requests = create_request_queue(self.policy) + req_index = 0 + # OMNI: Track requests that are already finished (e.g., marked by connector) + # These should be removed from running and not scheduled + already_finished_reqs: set[Request] = set() + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + if self.omni_connector is not None: + get_chunk_for_generation(self.omni_connector, request) + + # OMNI: Skip requests that are already finished or not in self.requests + # This can happen when connector marks request as finished + if request.status == RequestStatus.FINISHED_STOPPED or request.request_id not in self.requests: + already_finished_reqs.add(request) + req_index += 1 + continue + + num_computed_tokens = request.num_computed_tokens + required_tokens = max(len(request.prompt_token_ids) - num_computed_tokens, 1) + num_new_tokens = min(required_tokens, token_budget) + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + if new_blocks is None: + # Allocation failed (e.g., VRAM pressure); stop fast path and + # fall back to default scheduling + # Put the current request back to the head of the waiting queue + # Note: the original queue order is preserved + break + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + cached_prompt_token_ids[request.request_id] = request.prompt_token_ids + token_budget -= num_new_tokens + scheduled_running_reqs.append(request) + req_index += 1 + + # OMNI: Remove already finished requests from running queue + if already_finished_reqs: + self.running = remove_all(self.running, already_finished_reqs) + + # Fast path selection and scheduling (treat all as diffusion requests, + # independent of pooling_params) + while self.waiting and token_budget > 0 and len(self.running) < self.max_num_running_reqs: + request = self.waiting.peek_request() + if self.omni_connector is not None: + get_chunk_for_generation(self.omni_connector, request) + + # OMNI: Skip requests that are already finished or not in self.requests + # This can happen when connector marks request as finished + if request.status == RequestStatus.FINISHED_STOPPED or request.request_id not in self.requests: + # Pop the finished request from waiting queue and don't schedule it + self.waiting.pop_request() + continue + + # Uniformly treat as diffusion. A feature flag can be added later + # via config or request tag. 
+ + # Allocate all input tokens for the request in one shot + # (allocate 1 placeholder if zero) + required_tokens = max(len(request.prompt_token_ids), 1) + num_new_tokens = min(required_tokens, token_budget) + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + if new_blocks is None: + # Allocation failed (e.g., VRAM pressure); stop fast path and + # fall back to default scheduling + # Put the current request back to the head of the waiting queue + # Note: the original queue order is preserved + break + + # Officially schedule this request + request = self.waiting.pop_request() + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) + + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + scheduled_new_reqs.append(request) + + # Return skipped waiting requests + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # If fast path scheduled none, fall back to the original scheduling + if not num_scheduled_tokens: + return super().schedule() + + # Compute common prefix blocks (aligned with v1) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) + + # Assemble SchedulerOutput (align with v0.14.0) + if self.use_v2_model_runner: + # No resumed reqs in fast path; pass prefill_token_ids for new reqs. + new_reqs_data = [ + OmniNewRequestData.from_request( + req, + req_to_new_blocks[req.request_id].get_block_ids(), + getattr(req, "_all_token_ids", None), + ) + for req in scheduled_new_reqs + ] + else: + new_reqs_data = [ + OmniNewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + # No running/resumed reqs scheduled in our fast path + cached_reqs_data = self._make_cached_request_data( + running_reqs=scheduled_running_reqs, + resumed_reqs=[], + num_scheduled_tokens=num_scheduled_tokens, + spec_decode_tokens=scheduled_spec_decode_tokens, + req_to_new_blocks=req_to_new_blocks, + ) + + cached_reqs_data = OmniCachedRequestData( + req_ids=cached_reqs_data.req_ids, + resumed_req_ids=cached_reqs_data.resumed_req_ids, + new_token_ids=cached_reqs_data.new_token_ids, + all_token_ids=cached_reqs_data.all_token_ids, + new_block_ids=cached_reqs_data.new_block_ids, + num_computed_tokens=cached_reqs_data.num_computed_tokens, + num_output_tokens=cached_reqs_data.num_output_tokens, + prompt_token_ids=cached_prompt_token_ids, + ) + + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), + preempted_req_ids=set(), + ) + + # Record the request ids scheduled in this step (v0.14.0 behavior). 
+ self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + + # KVTransfer: package metadata + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + # EC Connector: package metadata + if self.ec_connector is not None: + ec_meta = self.ec_connector.build_connector_meta(scheduler_output) + scheduler_output.ec_connector_metadata = ec_meta + + # Update internal state (advance num_computed_tokens, free encoder inputs, + # etc.) + self._update_after_schedule(scheduler_output) + + try: + # Rewrap base NewRequestData entries with OmniNewRequestData, + # enriching with request-level payloads + new_list = [] + for nr in scheduler_output.scheduled_new_reqs: + req_id = getattr(nr, "req_id", None) + request = self.requests.get(req_id) if req_id else None + # Build omni entry preserving all base fields + omni_nr = OmniNewRequestData( + req_id=nr.req_id, + external_req_id=(getattr(request, "external_req_id", None) if request else None), + prompt_token_ids=nr.prompt_token_ids, + mm_features=nr.mm_features, + sampling_params=nr.sampling_params, + pooling_params=nr.pooling_params, + block_ids=nr.block_ids, + num_computed_tokens=nr.num_computed_tokens, + lora_request=nr.lora_request, + # Enrich with omni payloads from the live request object + prompt_embeds=(getattr(request, "prompt_embeds", None) if request else None), + additional_information=(getattr(request, "additional_information", None) if request else None), + ) + new_list.append(omni_nr) + + scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment] + except Exception: + # If anything goes wrong, leave the original output unchanged + init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData") + + return scheduler_output + + """ + Scheduler for the diffusion model. + This scheduler is modified to stop the request immediately for the diffusion model. + This is because the diffusion model can generate the final image/audio in one step. + Note: This is just a minimal modification to the original scheduler, + and there should be some further efforts to optimize the scheduler. + The original scheduler is still used for the AR model. + """ + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: OmniModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + """Update the scheduler state based on the model runner output. + + This method is modified to stop the request immediately for the diffusion model. 
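+
+        In practice, a diffusion request whose prompt tokens have all been computed
+        is marked FINISHED_STOPPED in this same step and freed, instead of being
+        kept for further autoregressive decoding.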
+ """ + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + + cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats + perf_stats: PerfStats | None = None + if self.perf_metrics and self.perf_metrics.is_enabled(): + perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output) + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: KVConnectorStats | None = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) + # Merge connector-side stats (align with v0.14.0) + if kv_connector_stats and self.connector: + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) + + failed_kv_load_req_ids = None + if kv_connector_output and getattr(kv_connector_output, "invalid_block_ids", None): + failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids) + + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. 
+ if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted, + ) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + pooler_output = pooler_outputs[req_index] if pooler_outputs else None + status_before_stop = request.status + finish_reason = None + routed_experts = None + + # Diffusion request: completes in one step; mark finished and free resources + if request.status == RequestStatus.FINISHED_STOPPED or ( + self.omni_connector is None and request.num_computed_tokens >= request.num_prompt_tokens + ): + request.status = RequestStatus.FINISHED_STOPPED + # Optional: set a stop_reason for front-end clarity + # (does not affect protocol) + request.stop_reason = request.stop_reason # or "generation_done" + stopped = True + + if stopped: + routed_experts = self._get_routed_experts(request) + finish_reason = request.get_finished_reason() + finished = self._handle_stopped_request(request) + if finished: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs: + new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) + + if new_token_ids and self.structured_output_manager.should_advance(request): + # NOTE: structured_output_request should not be None if + # use_structured_output, we have check above, so safe to ignore + # type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] # noqa: E501 + req_id, new_token_ids + ) + + # spec_token_ids comes from the model runner output + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None or kv_transfer_params or stopped: + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=finish_reason, + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + routed_experts=routed_experts, + num_nans_in_logits=request.num_nans_in_logits, + ) + ) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. 
+ if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) + + # Collect and publish KV cache events (align with v0.14.0) + events = self.kv_cache_manager.take_events() + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()} + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set) + finished_req_ids.clear() + + if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + return engine_core_outputs diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py new file mode 100644 index 0000000000000000000000000000000000000000..933ae7629c27b67a1376554e47b3e7011d543e2d --- /dev/null +++ b/vllm_omni/core/sched/output.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass, field + +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.request import Request + +from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload + + +@dataclass +class OmniNewRequestData(NewRequestData): + """New request data for omni models with embeddings support. + + Extends NewRequestData to include prompt embeddings and additional + information for direct transfer between pipeline stages. + + Args: + prompt_embeds: Optional serialized prompt embeddings payload + additional_information: Optional serialized additional information + dictionary containing tensors or lists + """ + + # Optional serialized prompt embeddings + prompt_embeds: PromptEmbedsPayload | None = None + # Optional external request ID for tracking + external_req_id: str | None = None + # Optional serialized additional information + additional_information: AdditionalInformationPayload | None = None + + @classmethod + def from_request( + cls, + request: Request, + block_ids: tuple[list[int], ...], + prefill_token_ids: list[int] | None = None, + ) -> "OmniNewRequestData": + """Create OmniNewRequestData from a Request object. 
+ + Args: + request: Request object to convert + block_ids: Tuple of block ID lists for KV cache allocation + + Returns: + OmniNewRequestData instance with data from the request + """ + return cls( + req_id=request.request_id, + external_req_id=request.external_req_id, + prompt_token_ids=request.prompt_token_ids, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + block_ids=block_ids, + num_computed_tokens=request.num_computed_tokens, + lora_request=request.lora_request, + prompt_embeds=request.prompt_embeds, + prefill_token_ids=prefill_token_ids, + additional_information=request.additional_information, + ) + + +@dataclass +class OmniCachedRequestData(CachedRequestData): + """Cached request data for omni models with embeddings support. + + Args: + prompt_token_ids: Mapping from request ID to list of prompt token IDs + """ + + prompt_token_ids: dict[str, list[int]] + + +@dataclass +class OmniSchedulerOutput(SchedulerOutput): + """Scheduler output with omni-specific transfer metadata.""" + + finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict) diff --git a/vllm_omni/diffusion/__init__.py b/vllm_omni/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/attention/__init__.py b/vllm_omni/diffusion/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..208f01a7cb5ee04c88d276fec2082cd4e830884b --- /dev/null +++ b/vllm_omni/diffusion/attention/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/attention/backends/__init__.py b/vllm_omni/diffusion/attention/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..208f01a7cb5ee04c88d276fec2082cd4e830884b --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/diffusion/attention/backends/abstract.py b/vllm_omni/diffusion/attention/backends/abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a62bcd9cc67d0679d9a153ccd2e41c953ced14 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/abstract.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, TypeVar + +import torch + +from vllm_omni.platforms import current_omni_platform + + +class AttentionBackend(ABC): + """Abstract class for diffusion attention backends.""" + + accept_output_buffer: bool = False + + @classmethod + def supports_attention_mask(cls) -> bool: + return False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_supported_head_sizes() -> list[int]: + """Get the list of supported head sizes 
for this backend."""
+        raise NotImplementedError
+
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        return (not supported_head_sizes) or head_size in supported_head_sizes
+
+
+@dataclass
+class AttentionMetadata:
+    attn_mask: torch.Tensor | None = None
+    # a joint mask for the joint query, key, and value, depending on joint_strategy
+    joint_attn_mask: torch.Tensor | None = None
+    # a replicated tensor among processes appended to the front or rear of query, depending on joint_strategy
+    joint_query: torch.Tensor | None = None
+    # a replicated tensor among processes appended to the front or rear of key, depending on joint_strategy
+    joint_key: torch.Tensor | None = None
+    # a replicated tensor among processes appended to the front or rear of value, depending on joint_strategy
+    joint_value: torch.Tensor | None = None
+    # the strategy for joining the query, key, and value; can be "front" or "rear"
+    joint_strategy: str = "front"
+
+
+T = TypeVar("T", bound=AttentionMetadata)
+
+
+class AttentionImpl(ABC, Generic[T]):
+    @abstractmethod
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        softmax_scale: float,
+        causal: bool = False,
+        num_kv_heads: int | None = None,
+        prefix: str = "",
+        **extra_impl_args,
+    ) -> None:
+        raise NotImplementedError
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: T | None = None,
+    ) -> torch.Tensor:
+        """Dispatch to platform-specific forward implementation."""
+        if current_omni_platform.is_rocm():
+            return self.forward_hip(query, key, value, attn_metadata)
+        elif current_omni_platform.is_cuda():
+            return self.forward_cuda(query, key, value, attn_metadata)
+        elif current_omni_platform.is_npu():
+            return self.forward_npu(query, key, value, attn_metadata)
+        elif current_omni_platform.is_xpu():
+            return self.forward_xpu(query, key, value, attn_metadata)
+        else:
+            raise NotImplementedError(f"No forward implementation for platform: {current_omni_platform}")
+
+    def forward_cuda(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: T | None = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def forward_npu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: T | None = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def forward_xpu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: T | None = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def forward_hip(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: T | None = None,
+    ) -> torch.Tensor:
+        # By default, HIP ops are compatible with CUDA ops.
+ return self.forward_cuda(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..45158e8f8bb751faeb39d94780303ab18833aa0d --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/flash_attn.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @classmethod + def supports_attention_mask(cls) -> bool: + return True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 96, 128, 192, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionImpl + + +class FlashAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.causal = causal + self.softmax_scale = softmax_scale + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + """CUDA/ROCm flash attention implementation.""" + # Import flash attention functions with fallback chain from utils/fa.py + # FA3 (fa3_fwd_interface) -> FA3 (flash_attn_interface) -> FA2 (flash_attn) + from vllm_omni.diffusion.attention.backends.utils.fa import ( + HAS_FLASH_ATTN, + _pad_input, + _unpad_input, + _upad_input, + flash_attn_func, + flash_attn_varlen_func, + ) + + if not HAS_FLASH_ATTN: + raise ImportError( + "FlashAttentionBackend requires Flash Attention. " + "Please install one of: fa3-fwd, flash-attention, or flash-attn. 
" + "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA" + ) + + query_length = query.size(1) + attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None + # Contains at least one padding token in the sequence + if attention_mask is not None and torch.any(~attention_mask): + assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)" + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query, key, value, attention_mask, query_length, _unpad_input + ) + + out_unpad = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **{ + "causal": self.causal, + "softmax_scale": self.softmax_scale, + }, + ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + + out = _pad_input(out_unpad, indices_q, query.size(0), query_length) + + else: + out = flash_attn_func( + query, + key, + value, + causal=self.causal, + softmax_scale=self.softmax_scale, + ) + # FA3 may return (out, lse) tuple, FA2 returns just out + if isinstance(out, tuple): + out = out[0] + return out + + def forward_npu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + """NPU attention implementation using mindiesd.""" + try: + from mindiesd import attention_forward + except ImportError: + raise ImportError( + "FlashAttentionBackend NPU implementation requires MindIE-SD. " + "Please install MindIE-SD to enable NPU attention support. " + "For installation details, see https://gitcode.com/Ascend/MindIE-SD" + "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA" + ) + + attention_mask = attn_metadata.attn_mask if attn_metadata else None + output = attention_forward( + query, + key, + value, + attn_mask=attention_mask, + opt_mode="manual", + op_type="fused_attn_score", + layout="BNSD", + ) + return output diff --git a/vllm_omni/diffusion/attention/backends/registry.py b/vllm_omni/diffusion/attention/backends/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..a77223d16783de7bc5390234c1ac46cb8975b42d --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/registry.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Diffusion attention backend registry. + +This module provides an enum-based registry for diffusion attention backends, +similar to vLLM's AttentionBackendEnum. Each backend registers its class path, +and platforms can override or extend backends using register_backend(). 
+""" + +from collections.abc import Callable +from enum import Enum, EnumMeta +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.utils.import_utils import resolve_obj_by_qualname + +if TYPE_CHECKING: + from vllm_omni.diffusion.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +class _DiffusionBackendEnumMeta(EnumMeta): + """Metaclass for DiffusionAttentionBackendEnum to provide better error messages.""" + + def __getitem__(cls, name: str) -> "DiffusionAttentionBackendEnum": + """Get backend by name with helpful error messages.""" + try: + return super().__getitem__(name) # type: ignore[return-value] + except KeyError: + members = list(cls.__members__.keys()) + valid_backends = ", ".join(members) + raise ValueError( + f"Unknown diffusion attention backend: '{name}'. Valid options are: {valid_backends}" + ) from None + + +class DiffusionAttentionBackendEnum(Enum, metaclass=_DiffusionBackendEnumMeta): + """Enumeration of all supported diffusion attention backends. + + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). + + To get the actual backend class (respecting overrides), use: + backend.get_class() + + Example: + # Get backend class + backend = DiffusionAttentionBackendEnum.FLASH_ATTN + backend_cls = backend.get_class() + + # Register custom backend + @register_diffusion_backend(DiffusionAttentionBackendEnum.CUSTOM) + class MyCustomBackend: + ... + """ + + # Common backends (available on most platforms) + FLASH_ATTN = "vllm_omni.diffusion.attention.backends.flash_attn.FlashAttentionBackend" + TORCH_SDPA = "vllm_omni.diffusion.attention.backends.sdpa.SDPABackend" + SAGE_ATTN = "vllm_omni.diffusion.attention.backends.sage_attn.SageAttentionBackend" + + def get_path(self, include_classname: bool = True) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If backend has empty path and is not registered + """ + path = _DIFFUSION_ATTN_OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. " + f"Use register_diffusion_backend(DiffusionAttentionBackendEnum.{self.name}, " + f"'your.module.YourClass')" + ) + if not include_classname: + path = path.rsplit(".", 1)[0] + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If backend has empty path and is not registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. + + Returns: + True if the backend has a registered override + """ + return self in _DIFFUSION_ATTN_OVERRIDES + + def clear_override(self) -> None: + """Clear any override for this backend, reverting to the default.""" + _DIFFUSION_ATTN_OVERRIDES.pop(self, None) + + +# Override registry +_DIFFUSION_ATTN_OVERRIDES: dict[DiffusionAttentionBackendEnum, str] = {} + + +def register_diffusion_backend( + backend: DiffusionAttentionBackendEnum, + class_path: str | None = None, +) -> Callable[[type], type]: + """Register or override a diffusion backend implementation. + + Args: + backend: The DiffusionAttentionBackendEnum member to register + class_path: Optional class path. 
If not provided and used as + decorator, will be auto-generated from the class. + + Returns: + Decorator function if class_path is None, otherwise a no-op + + Examples: + # Override an existing backend + @register_diffusion_backend(DiffusionAttentionBackendEnum.FLASH_ATTN) + class MyCustomFlashAttn: + ... + + # Override an existing backend (e.g., ASCEND_ATTN) + @register_diffusion_backend(DiffusionAttentionBackendEnum.ASCEND_ATTN) + class CustomAscendAttentionBackend: + ... + + # Direct registration + register_diffusion_backend( + DiffusionAttentionBackendEnum.CUSTOM, + "my.module.MyCustomBackend" + ) + """ + + def decorator(cls: type) -> type: + _DIFFUSION_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" + return cls + + if class_path is not None: + _DIFFUSION_ATTN_OVERRIDES[backend] = class_path + return lambda x: x + + return decorator diff --git a/vllm_omni/diffusion/attention/backends/ring/__init__.py b/vllm_omni/diffusion/attention/backends/ring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a3170408871aff475c6741479a0000f3257631 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/__init__.py @@ -0,0 +1 @@ +# Ring attention backend components diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_globals.py b/vllm_omni/diffusion/attention/backends/ring/ring_globals.py new file mode 100644 index 0000000000000000000000000000000000000000..d80d81fb45113094e5a24a2274e44a0c6ffa824e --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_globals.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + +# test if flash_attn (FA2) is available +try: + import flash_attn # noqa: F401 + from flash_attn.flash_attn_interface import _flash_attn_forward # noqa: F401 + + HAS_FLASH_ATTN = True +except (ImportError, ModuleNotFoundError): + HAS_FLASH_ATTN = False + +# FA3 detection: try multiple sources (forward only, no backward needed for inference) +# Source 1: flash_attn_interface (from flash-attention source build) +# Source 2: fa3_fwd_interface (from fa3-fwd PyPI package, supports Ampere/Ada/Hopper) +# Note: FA3 high-level API may or may not return softmax_lse depending on version. +# For Ring Attention which requires LSE, we fall back to low-level API if needed. 
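+# Illustrative sketch of how a caller is expected to consume these flags
+# (the downstream names below are only an example, not part of this file):
+#
+#     from vllm_omni.diffusion.attention.backends.ring import ring_globals
+#
+#     if ring_globals.HAS_FA3 and ring_globals.fa3_attn_func is not None:
+#         attn_func = ring_globals.fa3_attn_func      # FA3 forward-only path
+#     elif ring_globals.HAS_FLASH_ATTN:
+#         from flash_attn import flash_attn_func as attn_func  # FA2 fallback
+#     else:
+#         attn_func = None  # caller should fall back to torch SDPA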
+HAS_FA3 = False +fa3_fwd_func = None # Low-level forward function (_flash_attn_forward) +fa3_attn_func = None # High-level attention function (flash_attn_func) + +# Try flash_attn_interface first (from flash-attention source build) +try: + from flash_attn_interface import _flash_attn_forward as fa3_fwd_func # noqa: F401 + from flash_attn_interface import flash_attn_func as fa3_attn_func # noqa: F401 + + HAS_FA3 = True +except (ImportError, ModuleNotFoundError): + pass + +# Fallback: try fa3_fwd_interface (PyPI package, supports Ampere/Ada/Hopper) +if not HAS_FA3: + try: + from fa3_fwd_interface import _flash_attn_forward as fa3_fwd_func # noqa: F401 + from fa3_fwd_interface import flash_attn_func as fa3_attn_func # noqa: F401 + + HAS_FA3 = True + except (ImportError, ModuleNotFoundError): + pass + +# Legacy aliases for backward compatibility +HAS_FLASH_ATTN_HOPPER = HAS_FA3 +flash_attn_forward_hopper = fa3_fwd_func +flash3_attn_func = fa3_attn_func + +try: + from flashinfer.prefill import single_prefill_with_kv_cache # noqa: F401 + + HAS_FLASHINFER = True +except (ImportError, ModuleNotFoundError): + HAS_FLASHINFER = False + +try: + import aiter # noqa: F401 + from aiter import flash_attn_func as flash_attn_func_aiter # noqa: F401 + + HAS_AITER = True +except (ImportError, ModuleNotFoundError): + HAS_AITER = False + +try: + import sageattention # noqa: F401 + + HAS_SAGE_ATTENTION = True +except (ImportError, ModuleNotFoundError): + HAS_SAGE_ATTENTION = False + +try: + import spas_sage_attn # noqa: F401 + + HAS_SPARSE_SAGE_ATTENTION = True +except (ImportError, ModuleNotFoundError): + HAS_SPARSE_SAGE_ATTENTION = False + +try: + import torch_npu # noqa: F401 + + HAS_NPU = True +except (ImportError, ModuleNotFoundError): + HAS_NPU = False diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py b/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..ad61f9a2022b2e657dd8b06790cff08ac4634142 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_kernels.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. 
+# Adapted from https://github.com/feifeibear/long-context-attention + +import math + +import torch + +from .ring_globals import ( + HAS_AITER, + HAS_FA3, + HAS_FLASH_ATTN, + HAS_FLASHINFER, + fa3_fwd_func, +) + +_scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention +_scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention + +try: + import torch_musa # noqa: F401 + + _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_attention_flash_musa + _scaled_dot_product_efficient_attention = None +except ModuleNotFoundError: + pass + +if HAS_AITER: + from aiter import flash_attn_func as flash_attn_func_aiter + +if HAS_FLASH_ATTN: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward + +if HAS_FLASHINFER: + from flashinfer.prefill import single_prefill_with_kv_cache + + _LOG2_E = math.log2(math.e) + + +def pytorch_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, + op_type="efficient", +): + assert op_type in ["flash", "efficient"], f"Invalid op_type: {op_type}" + """ + q shape (bs, seqlen, nhead, hs) + k shape (bs, seqlen, nhead, hs) + v shape (bs, seqlen, nhead, hs) + """ + # Fallback logic: Flash Attention does not support float32. + # If op_type is 'flash' but dtype is float32, force 'efficient'. + if op_type == "flash" and q.dtype == torch.float32: + op_type = "efficient" + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if op_type == "flash": + out, lse = _scaled_dot_product_flash_attention( + q, + k, + v, + dropout_p=dropout_p, + is_causal=causal, + scale=softmax_scale, + )[:2] + elif op_type == "efficient": + out, lse = _scaled_dot_product_efficient_attention( + q, + k, + v, + attn_bias=None, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=causal, + scale=softmax_scale, + )[:2] + else: + raise ValueError(f"Invalid op_type: {op_type}") + + out = out.transpose(1, 2) + lse = lse.to(q.dtype) + + return out, lse + + +def flash_attn_forward( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, +): + assert HAS_FLASH_ATTN, "FlashAttention is not available" + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if flash_attn.__version__ < "2.6.3": + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + else: + block_out, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + return block_out, block_lse + + +def fa3_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax): + """FA3 forward pass for inference. + + FA3 supports Ampere, Ada, and Hopper GPUs. Dropout is ignored since FA3 is inference-only. + Uses low-level API (_flash_attn_forward) which always returns softmax_lse, + required for Ring Attention's correct accumulation. 
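+
+    Illustrative note on why the LSE is required (a sketch, not a transcription
+    of the kernel): when two partial results (out, lse) and (block_out, block_lse)
+    over disjoint key blocks are merged in update_out_and_lse, the combination is
+    numerically equivalent to
+
+        lse_new = log(exp(lse) + exp(block_lse))
+        out_new = exp(lse - lse_new) * out + exp(block_lse - lse_new) * block_out
+
+    so a flash-attention API that does not expose the LSE cannot be used for the
+    ring accumulation step.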
+ """ + assert HAS_FA3, "FA3 is not available" + assert fa3_fwd_func is not None, "FA3 low-level API (fa3_fwd_func) not available" + + # Low-level API always returns (out, softmax_lse, S_dmask, rng_state) + out, softmax_lse, *_ = fa3_fwd_func( + q, + k, + v, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0] if window_size else -1, + window_size_right=window_size[1] if window_size else -1, + softcap=softcap if softcap else 0.0, + ) + + return out, softmax_lse + + +# Legacy alias for backward compatibility +flash_attn3_func_forward = fa3_forward + + +def flash_attn_forward_aiter( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, +): + assert HAS_AITER, "Aiter is not available" + block_out, block_lse = flash_attn_func_aiter( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_lse=True, + ) + + return block_out, block_lse + + +def flashinfer_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + causal: bool = False, + window_size: tuple[int, int] = (-1, -1), + softcap: float | None = None, + alibi_slopes: torch.Tensor | None = None, + return_softmax: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert HAS_FLASHINFER, "FlashInfer is not available" + if q.ndim == 4: + if q.shape[0] > 1: + raise ValueError("batch size > 1 is not supported") + out, lse = single_prefill_with_kv_cache( + q[0], + k[0], + v[0], + sm_scale=softmax_scale, + causal=causal, + logits_soft_cap=softcap, + window_left=window_size[0], + return_lse=True, + ) + lse = lse.transpose(0, 1) + out, lse = out.unsqueeze(0), lse.unsqueeze(0) + elif q.ndim == 3: + out, lse = single_prefill_with_kv_cache( + q, + k, + v, + sm_scale=softmax_scale, + causal=causal, + logits_soft_cap=softcap, + window_left=window_size[0], + return_lse=True, + ) + lse = lse.transpose(0, 1) + else: + raise ValueError(f"Invalid input shape: {q.shape}") + lse = lse / _LOG2_E + return out, lse diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_selector.py b/vllm_omni/diffusion/attention/backends/ring/ring_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..77189c61b1c6682bed0b8ede0da976bb9be00474 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_selector.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Jiarui Fang. 
+# Adapted from https://github.com/feifeibear/long-context-attention + +from collections.abc import Callable +from enum import Enum +from functools import partial + +import torch + +from .ring_globals import ( + HAS_SAGE_ATTENTION, + HAS_SPARSE_SAGE_ATTENTION, +) +from .ring_kernels import ( + flash_attn3_func_forward, + flash_attn_forward, + flash_attn_forward_aiter, + flashinfer_attn_forward, + pytorch_attn_forward, +) + +if HAS_SAGE_ATTENTION: + import sageattention + +if HAS_SPARSE_SAGE_ATTENTION: + from spas_sage_attn.autotune import SparseAttentionMeansim + + +class AttnType(Enum): + AITER = "aiter" + FA = "fa" + FA3 = "fa3" + FLASHINFER = "flashinfer" + TORCH = "torch" + SAGE_AUTO = "sage_auto" + SAGE_FP16 = "sage_fp16" + SAGE_FP16_TRITON = "sage_fp16_triton" + SAGE_FP8 = "sage_fp8" + SAGE_FP8_SM90 = "sage_fp8_sm90" + SPARSE_SAGE = "sparse_sage" + + @classmethod + def from_string(cls, s: str): + for member in cls: + if member.value == s: + return member + raise ValueError(f"'{s}' is not a valid {cls.__name__}") + + +def select_flash_attn_impl( + impl_type: AttnType, + stage: str = "fwd-only", + attn_processor: torch.nn.Module | None = None, +) -> Callable[..., tuple[torch.Tensor, torch.Tensor | None]]: + """Select attention implementation for forward pass (inference only). + + Args: + impl_type: The attention implementation type. + stage: Must be "fwd-only" (backward not supported for inference). + attn_processor: Optional custom attention processor. + + Returns: + Callable[..., tuple[torch.Tensor, torch.Tensor | None]]: The attention + forward function for the specified implementation. + """ + if stage != "fwd-only": + raise ValueError(f"Only 'fwd-only' stage is supported for inference. Got: {stage}") + + if impl_type == AttnType.AITER: + return flash_attn_forward_aiter + + elif impl_type == AttnType.FA: + return flash_attn_forward + + elif impl_type == AttnType.FA3: + return flash_attn3_func_forward + + elif impl_type == AttnType.FLASHINFER: + return flashinfer_attn_forward + + elif impl_type == AttnType.TORCH: + return pytorch_attn_forward + + elif impl_type == AttnType.SAGE_AUTO: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn, + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP16: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp16_cuda, + pv_accum_dtype="fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP16_TRITON: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp16_triton, + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP8: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp8_cuda, + pv_accum_dtype="fp32+fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SAGE_FP8_SM90: + if not HAS_SAGE_ATTENTION: + raise ImportError("SageAttention is not available!") + return partial( + sageattention.sageattn_qk_int8_pv_fp8_cuda_sm90, + pv_accum_dtype="fp32+fp32", + tensor_layout="NHD", + return_lse=True, + ) + + elif impl_type == AttnType.SPARSE_SAGE: + if not HAS_SPARSE_SAGE_ATTENTION: + raise ImportError("SparseSageAttention is not available!") + if not isinstance(attn_processor, SparseAttentionMeansim): + 
raise ImportError("SparseSageAttention is only available with a SparseAttentionProcessor class passed in") + + def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs): + return ( + attn_processor( + q, + k, + v, + is_causal=causal, + scale=softmax_scale, + tensor_layout="NHD", + ), + None, + ) + + return fn + + elif attn_processor is not None: + return attn_processor + + else: + raise ValueError(f"Unknown flash attention implementation: {impl_type}") diff --git a/vllm_omni/diffusion/attention/backends/ring/ring_utils.py b/vllm_omni/diffusion/attention/backends/ring/ring_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c256f62cbd9cb423f4197f5f057675b6ff754e85 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring/ring_utils.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + + +import torch +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "flatten_varlen_lse", "unflatten_varlen_lse"] + + +# Remove torch.jit.script for debugging and flexible shape handling +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.to(torch.float32) + + B, S, H, D = out.shape + + # --- Shape Correction Logic for block_lse --- + # Goal: block_lse should be (B, S, H, 1) to match out (B, S, H, D) + + # Debug info + # print(f"DEBUG _update: out={out.shape}, block_lse={block_lse.shape}") + + # Case 0: If block_lse is already 4D, check if it matches + if block_lse.dim() == 4: + if block_lse.shape[1] == S and block_lse.shape[2] == H: + pass # Good + elif block_lse.shape[1] == H and block_lse.shape[2] == S: + block_lse = block_lse.transpose(1, 2) + elif block_lse.shape[1] == H and block_lse.shape[2] >= S: # Padding case + block_lse = block_lse[:, :, :S, :].transpose(1, 2) + # If shape is (B, H, S, 1) but expected (B, S, H, 1) because out is (B, S, H, D) + elif block_lse.shape[1] == H and block_lse.shape[2] == S and block_lse.shape[3] == 1: + block_lse = block_lse.transpose(1, 2) + + # Case 1: block_lse is 3D (B, H, S) or (B, S, H) or (B, ?, ?) + elif block_lse.dim() == 3: + # Check for (B, H, S) - Standard SDPA/FA output + if block_lse.shape[1] == H and block_lse.shape[2] == S: + block_lse = block_lse.transpose(1, 2).unsqueeze(-1) + + # Check for (B, S, H) + elif block_lse.shape[1] == S and block_lse.shape[2] == H: + block_lse = block_lse.unsqueeze(-1) + + # Check for Padding: (B, H, S_pad) where S_pad >= S + elif block_lse.shape[1] == H and block_lse.shape[2] >= S: + # print(f"DEBUG: Trimming padding from lse. {block_lse.shape} -> S={S}") + block_lse = block_lse[:, :, :S].transpose(1, 2).unsqueeze(-1) + + # Check for weird case: (B, S, H_pad) ? 
Unlikely for LSE but possible + elif block_lse.shape[1] == S and block_lse.shape[2] >= H: + block_lse = block_lse[:, :, :H].unsqueeze(-1) + + # Check for flipped weird case: (B, S_pad, H) + elif block_lse.shape[1] >= S and block_lse.shape[2] == H: + block_lse = block_lse[:, :S, :].unsqueeze(-1) + + # --- Shape Correction for lse (internal state) --- + # Ensure lse matches block_lse's corrected shape (B, S, H, 1) + if lse.shape != block_lse.shape: + # If lse was initialized with wrong shape, try to fix it + if lse.dim() == 4 and lse.shape[1] == block_lse.shape[2] and lse.shape[2] == block_lse.shape[1]: + lse = lse.transpose(1, 2) + elif lse.shape[1] >= S: # slice if lse was initialized with padding + lse = lse[:, :S, :, :] + + # Final check + if lse.shape != block_lse.shape: + # Force broadcast if possible? + pass + + try: + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + except RuntimeError as e: + print(f"ERROR in _update_out_and_lse: {e}") + print(f"out: {out.shape}, lse: {lse.shape}") + print(f"block_out: {block_out.shape}, block_lse: {block_lse.shape}") + # raise e + raise e + + return out, lse + + +def update_out_and_lse( + out: torch.Tensor | None, + lse: torch.Tensor | None, + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + + out = block_out.to(torch.float32) + + # Initialize LSE with robust logic (same as _update) + B, D1, D2, D3 = out.shape + + S_guess = D1 + H_guess = D2 + + if block_lse.dim() == 3: + if block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess: + lse = block_lse.transpose(1, 2).unsqueeze(-1) + elif block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess: + lse = block_lse.unsqueeze(-1) + elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding + lse = block_lse[:, :, :S_guess].transpose(1, 2).unsqueeze(-1) + elif block_lse.shape[1] == S_guess and block_lse.shape[2] >= H_guess: # Padding/Weird + lse = block_lse[:, :, :H_guess].unsqueeze(-1) + elif block_lse.shape[1] >= S_guess and block_lse.shape[2] == H_guess: + lse = block_lse[:, :S_guess, :].unsqueeze(-1) + + # Reverse case: What if out is (B, H, S, D) so S=D2, H=D1? + elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S) + # Then out is (B, H, S, D). We should transpose out! + out = out.transpose(1, 2) + lse = block_lse[:, :, :D2].transpose(1, 2).unsqueeze(-1) # (B, S, H, 1) + + else: + # Fallback + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + else: + # Case 0: If block_lse is already 4D, check if it matches + if block_lse.dim() == 4: + if block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess: + lse = block_lse + elif block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess: + lse = block_lse.transpose(1, 2) + elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding case + lse = block_lse[:, :, :S_guess, :].transpose(1, 2) + elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S) + # Then out is (B, H, S, D). We should transpose out! 
+ out = out.transpose(1, 2) + lse = block_lse[:, :, :D2].transpose(1, 2) # (B, S, H, 1) + else: + lse = block_lse + else: + lse = block_lse + + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() diff --git a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..dd27f88a8c1bd73905bc4c4046355fe9e26dd350 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + + +import torch + +from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType, select_flash_attn_impl +from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse +from vllm_omni.diffusion.distributed.comm import RingComm + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, + attn_processor=None, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +): + # Validate causal + joint_strategy combination + # When causal=True and joint_strategy="rear", the causal mask would incorrectly + # prevent local query tokens from attending to joint key tokens (which are + # concatenated at the end). This breaks the semantics where joint tokens + # (e.g., text conditioning) should be visible to all local tokens. + if causal and joint_tensor_key is not None and joint_strategy == "rear": + raise ValueError( + "joint_strategy='rear' is not compatible with causal=True in Ring Attention. " + "When using causal attention with joint tokens, use joint_strategy='front' " + "to ensure joint tokens act as a visible prefix for all local tokens. " + "With 'rear' strategy, the causal mask would incorrectly block local tokens " + "from seeing the joint tokens." 
+ ) + + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + # Check and adjust q, k, v to be contiguous + if not q.is_contiguous(): + q = q.contiguous() + if not k.is_contiguous(): + k = k.contiguous() + if not v.is_contiguous(): + v = v.contiguous() + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor + next_v: torch.Tensor + next_k = comm.send_recv(k) + next_v = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + step_k = k + step_v = v + if step == 0 and joint_tensor_key is not None: + if joint_strategy == "front": + step_k = torch.cat([joint_tensor_key, step_k], dim=1) + step_v = torch.cat([joint_tensor_value, step_v], dim=1) + else: + step_k = torch.cat([step_k, joint_tensor_key], dim=1) + step_v = torch.cat([step_v, joint_tensor_value], dim=1) + + fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor) + block_out, block_lse = fn( + q, + step_k, + step_v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal and step == 0, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + # Ensure block_out is contiguous if needed, though usually it is from FA + + if attn_type == AttnType.SPARSE_SAGE: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + if attn_type != AttnType.SPARSE_SAGE: + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +class RingFlashAttnFunc(torch.autograd.Function): + """Ring Flash Attention autograd function (inference only, no backward).""" + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + group, + attn_type, + attn_processor, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=False, + attn_type=attn_type, + attn_processor=attn_processor, + joint_tensor_key=joint_tensor_key, + joint_tensor_value=joint_tensor_value, + joint_strategy=joint_strategy, + ) + return out if not return_softmax else (out, softmax_lse, None) + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + None, # attn_processor + None, # joint_tensor_key + None, # joint_tensor_value + "front", # joint_strategy + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, 
+ attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + None, # attn_processor + None, # joint_tensor_key + None, # joint_tensor_value + "front", # joint_strategy + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, + attn_processor=None, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, None]: + """Ring Attention forward pass using Flash Attention backend. + + Implements Ring Attention with sequence parallelism using a ring-based P2P + communication pattern. The sequence dimension is sharded across devices, and + Key/Value blocks are circulated through the ring to accumulate attention results. + + Args: + q (torch.Tensor): Query tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + k (torch.Tensor): Key tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + v (torch.Tensor): Value tensor of shape (batch, seq_len, num_heads, head_dim). + Sequence dimension is sharded across the ring group. + dropout_p (float): Dropout probability. Defaults to 0.0. + softmax_scale (float | None): Scaling factor for softmax. + If None, computed as head_dim^(-0.5). + causal (bool): Whether to apply causal masking. Defaults to False. + window_size (tuple[int, int]): Sliding window size for attention. + (-1, -1) means no windowing. + softcap (float): Soft capping value for attention logits. Defaults to 0.0. + alibi_slopes (torch.Tensor | None): ALiBi slopes for positional bias. + Not supported. + deterministic (bool): Whether to use deterministic algorithms. + Defaults to False. + return_attn_probs (bool): If True, returns (out, softmax_lse, None). + Defaults to False. + group (ProcessGroup | None): Process group for ring communication. + Defaults to None. + attn_type (AttnType): Flash Attention implementation type + (AttnType.FA, AttnType.FA3, etc.). + attn_processor (Callable | None): Custom attention processor for sparse + attention. Defaults to None. + joint_tensor_key (torch.Tensor | None): Additional key tensor for joint + attention (e.g., text + image). Concatenated only at step=0. + Defaults to None. + joint_tensor_value (torch.Tensor | None): Additional value tensor for + joint attention (e.g., text + image). Concatenated only at step=0. + Defaults to None. + joint_strategy (str): Concatenation strategy ("front" or "back"). + Defaults to "front". + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, None]]: + - If return_attn_probs is False: Output tensor (batch, seq_len, num_heads, head_dim). + - If return_attn_probs is True: A tuple (out, softmax_lse, None). 
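+
+    Example:
+        A minimal illustrative sketch (assuming ``ring_pg`` is an already
+        initialized ring process group and q/k/v hold this rank's local
+        sequence shards of shape (batch, local_seq_len, num_heads, head_dim)):
+
+            out = ring_flash_attn_func(
+                q, k, v,
+                causal=False,
+                group=ring_pg,
+                attn_type=AttnType.FA,
+            )
+
+        The output has the same shape as q and stays sharded along the
+        sequence dimension.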
+ """ + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + attn_processor, + joint_tensor_key, + joint_tensor_value, + joint_strategy, + ) diff --git a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..482cbc9f89233b482146559a02d53136ed458771 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024, Jiarui Fang. +# Adapted from https://github.com/feifeibear/long-context-attention + +# adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py +# Copyright 2024 The HuggingFace Inc. team and Jiarui Fang. + + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.ring.ring_kernels import pytorch_attn_forward +from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse +from vllm_omni.diffusion.distributed.comm import RingComm + +logger = init_logger(__name__) + + +def ring_pytorch_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + op_type="efficient", + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", +): + return RingAttentionFunc.apply( + group, + q, + k, + v, + softmax_scale, + causal, + op_type, + joint_tensor_key, + joint_tensor_value, + joint_strategy, + ) + + +class RingAttentionFunc(torch.autograd.Function): + """Ring Attention autograd function using PyTorch SDPA (inference only, no backward).""" + + @staticmethod + def forward( + ctx, + group, + q, + k, + v, + sm_scale, + is_causal, + op_type, + joint_tensor_key=None, + joint_tensor_value=None, + joint_strategy="front", + ): + # Validate causal + joint_strategy combination + # When causal=True and joint_strategy="rear", the causal mask would incorrectly + # prevent local query tokens from attending to joint key tokens (which are + # concatenated at the end). This breaks the semantics where joint tokens + # (e.g., text conditioning) should be visible to all local tokens. + if is_causal and joint_tensor_key is not None and joint_strategy == "rear": + raise ValueError( + "joint_strategy='rear' is not compatible with causal=True in Ring Attention. " + "When using causal attention with joint tokens, use joint_strategy='front' " + "to ensure joint tokens act as a visible prefix for all local tokens. " + "With 'rear' strategy, the causal mask would incorrectly block local tokens " + "from seeing the joint tokens." 
+ ) + + comm = RingComm(group) + # Ensure tensors are contiguous for P2P communication + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + out, lse = None, None + next_k, next_v = None, None + + if sm_scale is None: + sm_scale = q.shape[-1] ** -0.5 + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k = comm.send_recv(k) + next_v = comm.send_recv(v) + comm.commit() + + if not is_causal or step <= comm.rank: + step_k = k + step_v = v + if step == 0 and joint_tensor_key is not None: + if joint_strategy == "front": + step_k = torch.cat([joint_tensor_key, step_k], dim=1) + step_v = torch.cat([joint_tensor_value, step_v], dim=1) + else: + step_k = torch.cat([step_k, joint_tensor_key], dim=1) + step_v = torch.cat([step_v, joint_tensor_value], dim=1) + + block_out, block_lse = pytorch_attn_forward( + q, + step_k, + step_v, + softmax_scale=sm_scale, + causal=is_causal and step == 0, + op_type=op_type, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + + return out diff --git a/vllm_omni/diffusion/attention/backends/sage_attn.py b/vllm_omni/diffusion/attention/backends/sage_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..af34a063393904888456de94ca14bf18022e315c --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/sage_attn.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) + +logger = init_logger(__name__) + +try: + from sageattention import sageattn +except ImportError: + logger.warning( + "SageAttentionBackend is not available. 
You may install sage-attention" + " by pip install git+https://github.com/thu-ml/SageAttention.git" + ) + raise ImportError + +# TODO add sage3 attention backend + + +class SageAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "SAGE_ATTN" + + @staticmethod + def get_impl_cls() -> type["SageAttentionImpl"]: + return SageAttentionImpl + + +class SageAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + output = sageattn( + query, + key, + value, + tensor_layout="NHD", + is_causal=self.causal, + sm_scale=self.softmax_scale, + ) + return output diff --git a/vllm_omni/diffusion/attention/backends/sdpa.py b/vllm_omni/diffusion/attention/backends/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe460e3db4df7a3f0449a55081d158eef01df78 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/sdpa.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) + +logger = init_logger(__name__) + + +def _maybe_reshape_attn_mask(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None): + """ + Reshape Attention Mask + [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k] + """ + # Skip Attention Mask if all values are 1, `None` mask can speedup the computation + if attn_mask is not None and torch.all(attn_mask != 0): + attn_mask = None + + # Reshape Attention Mask + # [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k] + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] + attn_mask = attn_mask.to(torch.bool) + attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() + return attn_mask + + +class SDPABackend(AttentionBackend): + accept_output_buffer: bool = True + + @classmethod + def supports_attention_mask(cls) -> bool: + return True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [x for x in range(1024)] # todo + + @staticmethod + def get_name() -> str: + return "SDPA" + + @staticmethod + def get_impl_cls() -> type["SDPAImpl"]: + return SDPAImpl + + +class SDPAImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + attention_mask = 
attn_metadata.attn_mask if attn_metadata else None + output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=self.causal, + scale=self.softmax_scale, + ) + out = output.permute(0, 2, 1, 3) + return out + + def forward_xpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + return self.forward_cuda(query, key, value, attn_metadata) + + def forward_hip( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + return self.forward_cuda(query, key, value, attn_metadata) + + def forward_npu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + if attn_metadata: + attention_mask = _maybe_reshape_attn_mask(query, key, attn_metadata.attn_mask) + setattr(attn_metadata, "attn_mask", attention_mask) + return self.forward_cuda(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/attention/backends/utils/__init__.py b/vllm_omni/diffusion/attention/backends/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92c7c8027cb2456d42f9452cce7c60c0637b4638 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/utils/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utils for attention backends. +""" + +from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input + +__all__ = [ + "_pad_input", + "_unpad_input", + "_upad_input", +] diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py new file mode 100644 index 0000000000000000000000000000000000000000..3344c99638146f73ecc23fed1c9f33afc156fd80 --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/utils/fa.py @@ -0,0 +1,259 @@ +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py +import torch +import torch.nn.functional as F + +from vllm_omni.platforms import current_omni_platform + +# Flash Attention function detection with fallback chain +flash_attn_func = None +flash_attn_varlen_func = None + +if current_omni_platform.is_rocm(): + # ROCm: try Aiter first + try: + from vllm._aiter_ops import is_aiter_found_and_supported + + if is_aiter_found_and_supported(): + from aiter import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass +else: + # CUDA: try FA3 -> FA2 fallback chain + # Try FA3 from fa3-fwd PyPI package + try: + from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + # Fallback: Try FA3 from flash-attention source build + if flash_attn_func is None: + try: + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + # Fallback: Try FA2 from flash-attn package (try multiple import paths) + if flash_attn_func is None: + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + if flash_attn_func is None: + try: + from flash_attn.flash_attn_interface import ( # noqa: F401 + flash_attn_func, + flash_attn_varlen_func, + ) + except (ImportError, ModuleNotFoundError): + pass + +# If no FA backend available, SDPA backend will be selected at the platform level +# flash_attn_func and flash_attn_varlen_func will be None +HAS_FLASH_ATTN = flash_attn_func is not None + + +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
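+
+    Example (illustrative):
+        With attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]] and no unused_mask,
+        the per-sequence lengths are [3, 2], so indices = [0, 1, 2, 4, 5],
+        cu_seqlens = [0, 3, 5], max_seqlen_in_batch = 3, and the returned
+        hidden_states keeps 5 token rows.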
+ """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _pad_input(hidden_states, indices, batch, seqlen): + """ + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong + to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in + order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. 
+ unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + if torch.compiler.is_compiling(): + # allow PyTorch compiler to include operations that return scalar values (like .item() + torch._dynamo.config.capture_scalar_outputs = True + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> + # we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. 
+ we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..83cfe84d7b6e39fbe8aedd18a5491964c8c7ac64 --- /dev/null +++ b/vllm_omni/diffusion/attention/layer.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team & Jiarui Fang +# Adapted from +# https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py + + +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend +from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy +from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention +from vllm_omni.diffusion.attention.selector import get_attn_backend +from vllm_omni.diffusion.distributed.parallel_state import get_sp_group +from vllm_omni.diffusion.forward_context import get_forward_context + +logger = init_logger(__name__) + + +class Attention(nn.Module): + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + # ulysses attention + scatter_idx: int = 2, + gather_idx: int = 1, + use_sync: bool = False, + ): + super().__init__() + self.attn_backend = get_attn_backend(-1) + self.attn_impl_cls = self.attn_backend.get_impl_cls() + self.attention = self.attn_impl_cls( + num_heads=num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + causal=causal, + num_kv_heads=num_kv_heads, + ) + # Instantiate fallback backend for float32 support + self.sdpa_fallback = SDPABackend.get_impl_cls()( + num_heads=num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + causal=causal, + num_kv_heads=num_kv_heads, + ) + self.backend_pref = None + + self.softmax_scale = softmax_scale + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.use_sync = use_sync + self.causal = causal + + self.use_ring = False + self.ring_pg = None + self.ring_runner = None + + try: + config = get_forward_context().omni_diffusion_config + self.backend_pref = config.attention_backend + if config.parallel_config.ring_degree > 1: + self.use_ring = True + try: + sp_group = get_sp_group() + self.ring_pg = sp_group.ring_group + self.ring_runner = RingParallelAttention(sp_group) + except Exception: + self.use_ring = False + self.ring_runner = None + except Exception: + self.use_ring = False + self.ring_runner = None + + self.parallel_strategy = build_parallel_attention_strategy( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + # 1. 
Prepare inputs (Communication / Resharding) + # For Ulysses: AllToAll Q/K/V; Slicing joint_q/k/v + # For Ring: Concat joint_q + query, key, value, attn_metadata, ctx = self.parallel_strategy.pre_attention(query, key, value, attn_metadata) + + # 2. Kernel Execution (Computation) + if self.use_ring: + out = self._run_ring_attention(query, key, value, attn_metadata) + else: + out = self._run_local_attention(query, key, value, attn_metadata) + + # 3. Post-processing (Reverse Communication) + # For Ulysses: AllToAll Output, and AllGather Joint Output + out = self.parallel_strategy.post_attention(out, ctx) + + return out + + def _run_local_attention(self, query, key, value, attn_metadata): + if query.dtype == torch.float32: + logger.warning_once( + f"Only SDPA supports float32. Overriding user config {type(self.attention)} " + f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}." + ) + return self.sdpa_fallback.forward(query, key, value, attn_metadata) + + # Fallback to standard attention + return self.attention.forward(query, key, value, attn_metadata) + + def _run_ring_attention(self, query, key, value, attn_metadata): + # Delegate to RingParallelAttention strategy if available + if self.ring_runner is not None: + return self.ring_runner.run_attention( + query, key, value, attn_metadata, softmax_scale=self.softmax_scale, causal=self.causal + ) + + raise RuntimeError("Ring attention is enabled but strategy is not RingParallelAttention") diff --git a/vllm_omni/diffusion/attention/parallel/__init__.py b/vllm_omni/diffusion/attention/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49d776f0c041e517f010e17a9d65faf32c522837 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Parallel attention strategies. + +This package provides **communication / resharding strategies** for attention, +orthogonal to the **attention kernel backend** (SDPA/Flash/Sage). + +The goal is to keep `vllm_omni.diffusion.attention.layer.Attention` small and +extensible: adding a new parallelism method should not require editing the core +Attention module, only adding a new strategy and selecting it in the factory. +""" + +from .base import NoParallelAttention, ParallelAttentionContext, ParallelAttentionStrategy +from .factory import build_parallel_attention_strategy + +__all__ = [ + "ParallelAttentionStrategy", + "ParallelAttentionContext", + "NoParallelAttention", + "build_parallel_attention_strategy", +] diff --git a/vllm_omni/diffusion/attention/parallel/base.py b/vllm_omni/diffusion/attention/parallel/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c98add4228621d06878c101b4820d14aeff337a9 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/base.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + +import torch + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + + +@dataclass(frozen=True, slots=True) +class ParallelAttentionContext: + """Opaque per-forward context returned by a parallel strategy. + + Strategies may stash whatever they need here to finish post-processing after + the attention kernel runs (e.g. reverse resharding, slicing metadata, etc.). 
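+
+    Concrete strategies extend this dataclass with whatever they need; for
+    example, the Ulysses strategy stores its process group, scatter/gather
+    indices, and the joint prefix length used to undo the resharding.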
+ """ + + name: str + + +class ParallelAttentionStrategy(Protocol): + """Pluggable strategy for parallel attention communication/resharding. + + This is intentionally orthogonal to the attention *kernel* backend. + The kernel backend implements `AttentionImpl.forward()` for a given device, + while the parallel strategy implements how Q/K/V and outputs are sharded / + communicated across ranks. + """ + + @property + def enabled(self) -> bool: ... + + @property + def name(self) -> str: ... + + def pre_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, AttentionMetadata | None, ParallelAttentionContext | None]: + """Runs before the attention kernel. + + Returns possibly transformed Q/K/V and metadata, and an optional context + for `post_attention`. + """ + + def post_attention( + self, + attn_output: torch.Tensor, + ctx: ParallelAttentionContext | None, + ) -> torch.Tensor: + """Runs after the attention kernel.""" + + +class NoParallelAttention: + """Default strategy: do nothing (single device / no SP).""" + + @property + def enabled(self) -> bool: + return False + + @property + def name(self) -> str: + return "none" + + def pre_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + ): + return query, key, value, attn_metadata, None + + def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: + return attn_output diff --git a/vllm_omni/diffusion/attention/parallel/factory.py b/vllm_omni/diffusion/attention/parallel/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ce95d409adf106fd5373648a1df0f5638f521339 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/factory.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.parallel.base import NoParallelAttention, ParallelAttentionStrategy +from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention +from vllm_omni.diffusion.attention.parallel.ulysses import UlyssesParallelAttention +from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size, get_sp_group +from vllm_omni.diffusion.forward_context import get_forward_context + +logger = init_logger(__name__) + + +def build_parallel_attention_strategy( + *, + scatter_idx: int, + gather_idx: int, + use_sync: bool, +) -> ParallelAttentionStrategy: + """Select a parallel attention strategy based on current diffusion config. + + Design principle: + - Attention kernel backend selection remains in `attention/selector.py`. + - Parallel attention selection is handled here, based on distributed config + and initialized process groups. 
+ """ + try: + cfg = get_forward_context().omni_diffusion_config + p = cfg.parallel_config + except Exception as e: + logger.debug(f"No forward context available for parallel attention strategy: {e}") + return NoParallelAttention() + + ulysses_degree = getattr(p, "ulysses_degree", 1) + ring_degree = getattr(p, "ring_degree", 1) + + try: + sp_group = get_sp_group() + # Ensure SP group is initialized and world size > 1 + if get_sequence_parallel_world_size() <= 1: + return NoParallelAttention() + except Exception as e: + # Log warning if SP is configured but group is not available + if ulysses_degree > 1 or ring_degree > 1: + logger.warning( + f"SP configured (ulysses={ulysses_degree}, ring={ring_degree}) but SP group not available: {e}. " + f"Falling back to NoParallelAttention. This may cause incorrect results." + ) + return NoParallelAttention() + + # Ulysses (or Hybrid Ulysses+Ring) + if ulysses_degree > 1: + logger.debug(f"Using UlyssesParallelAttention (ulysses_degree={ulysses_degree})") + return UlyssesParallelAttention( + sp_group=sp_group, + scatter_idx=scatter_idx, + gather_idx=gather_idx, + use_sync=use_sync, + ) + + # Pure Ring Attention + if ring_degree > 1: + logger.debug(f"Using RingParallelAttention (ring_degree={ring_degree})") + return RingParallelAttention( + sp_group=sp_group, + ) + + return NoParallelAttention() diff --git a/vllm_omni/diffusion/attention/parallel/ring.py b/vllm_omni/diffusion/attention/parallel/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..090d2ef61e87c206a8ddb2f820833bbf08024729 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/ring.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from vllm.logger import init_logger + +# import torch.distributed as dist # Not used directly here, but good practice if needed +from vllm_omni.diffusion.attention.backends.ring.ring_globals import HAS_FA3, HAS_FLASH_ATTN +from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType +from vllm_omni.diffusion.attention.parallel.base import ( + ParallelAttentionContext, + # ParallelAttentionStrategy, # Not used in type hint below currently +) +from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator + +# from vllm_omni.diffusion.attention.backends.ring_selector import AttnType # Already imported above +from vllm_omni.diffusion.forward_context import get_forward_context + +if TYPE_CHECKING: + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + + +@dataclass(frozen=True, slots=True) +class _RingCtx(ParallelAttentionContext): + """Per-forward context for Ring sequence-parallel attention.""" + + # Ring attention typically doesn't need complex context for post-processing + # as the output is already correctly sharded along sequence dimension. + pass + + +class RingParallelAttention: + """Ring sequence-parallel strategy. + + This strategy prepares inputs for Ring Attention. + Key responsibilities: + - Concatenate joint_query (Text) to query (Image) if present. + - Keep joint_key/value separate in metadata for the Ring kernel to handle as static prefix. 
+ """ + + def __init__( + self, + sp_group: SequenceParallelGroupCoordinator, + attn_backend_pref: str | None = None, + ) -> None: + self._sp_group = sp_group + self.attn_backend_pref = attn_backend_pref + + @property + def enabled(self) -> bool: + return True + + @property + def name(self) -> str: + return "ring" + + def pre_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + ): + joint_tensor_query = None + joint_strategy = "front" + + if attn_metadata is not None: + joint_tensor_query = attn_metadata.joint_query + joint_strategy = attn_metadata.joint_strategy + + if joint_tensor_query is not None: + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError(f"joint_strategy: {joint_strategy} not supported.") + + if joint_strategy == "front": + query = torch.cat([joint_tensor_query, query], dim=1) + else: + query = torch.cat([query, joint_tensor_query], dim=1) + + # Note: We do NOT concatenate joint_key/value here. + # They are preserved in attn_metadata and will be passed + # explicitly to ring_flash_attn_func. + + ctx = _RingCtx(name=self.name) + return query, key, value, attn_metadata, ctx + + def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: + # Ring attention output is already sharded correctly along sequence dimension. + return attn_output + + def run_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + softmax_scale: float | None = None, + causal: bool = False, + ) -> torch.Tensor: + """Run the actual Ring Attention kernel.""" + if softmax_scale is None: + softmax_scale = query.shape[-1] ** -0.5 + + backend_pref = self.attn_backend_pref + if backend_pref is None: + try: + config = get_forward_context().omni_diffusion_config + # config might not have attention_backend attribute if not updated + backend_pref = getattr(config, "attention_backend", None) + except Exception: + backend_pref = None + + # Determine attention type with fallback chain: FA3 -> FA2 -> SDPA + # FP32 is not supported by Flash Attention, force SDPA + if query.dtype == torch.float32: + backend_pref = "sdpa" + elif not HAS_FA3 and not HAS_FLASH_ATTN: + if backend_pref != "sdpa": + logger = init_logger(__name__) + logger.warning_once("Flash Attention (FA2/FA3) is not available! 
Force enabling SDPA.") + backend_pref = "sdpa" + + # Extract joint tensors + joint_key, joint_value = None, None + joint_strategy = "front" + if attn_metadata is not None: + joint_key = attn_metadata.joint_key + joint_value = attn_metadata.joint_value + if attn_metadata.joint_strategy is not None: + joint_strategy = attn_metadata.joint_strategy + + if backend_pref == "sdpa" or backend_pref == "torch": + from vllm_omni.diffusion.attention.backends.ring_pytorch_attn import ring_pytorch_attn_func + + return ring_pytorch_attn_func( + query, + key, + value, + softmax_scale=softmax_scale, + causal=causal, + group=self._sp_group.ring_group, + op_type="efficient", + joint_tensor_key=joint_key, + joint_tensor_value=joint_value, + joint_strategy=joint_strategy, + ) + + from vllm_omni.diffusion.attention.backends.ring_flash_attn import ring_flash_attn_func + + # Prefer FA3 over FA2 for better performance (FA3 supports Ampere/Ada/Hopper) + attn_type = AttnType.FA3 if HAS_FA3 else AttnType.FA + + return ring_flash_attn_func( + query, + key, + value, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + group=self._sp_group.ring_group, + attn_type=attn_type, + joint_tensor_key=joint_key, + joint_tensor_value=joint_value, + joint_strategy=joint_strategy, + ) diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py new file mode 100644 index 0000000000000000000000000000000000000000..66914102ef8ef801ff1dca834f360142f9ab8b41 --- /dev/null +++ b/vllm_omni/diffusion/attention/parallel/ulysses.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.distributed as dist + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.parallel.base import ParallelAttentionContext +from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D +from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator + + +@dataclass(frozen=True, slots=True) +class _UlyssesCtx(ParallelAttentionContext): + """Per-forward context for Ulysses sequence-parallel attention.""" + + ulysses_pg: dist.ProcessGroup + scatter_idx: int + gather_idx: int + use_sync: bool + joint_len: int = 0 + joint_strategy: str = "front" + + +class UlyssesParallelAttention: + """Ulysses sequence-parallel strategy (all-to-all over seq/head dims). + + This preserves the semantics previously implemented in + `Attention._forward_ulysses`: + - If `AttentionMetadata.joint_*` is provided, joint_query/key/value are + concatenated *after* all-to-all. + - joint_key/value are assumed to be replicated across SP ranks and are sliced + by ulysses head rank before concatenation. 
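+
+    Shape flow (P = ulysses world size): Q/K/V enter as
+    (bs, seq_len/P, head_cnt, head_size), are all-to-all'ed to
+    (bs, seq_len, head_cnt/P, head_size) before the kernel, and the output is
+    all-to-all'ed back in post_attention(); joint tensors are sliced to
+    head_cnt/P heads so concatenation along the sequence dimension stays
+    consistent, and the joint part of the output is all-gathered over heads
+    instead.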
+ """ + + def __init__( + self, + sp_group: SequenceParallelGroupCoordinator, + scatter_idx: int, + gather_idx: int, + use_sync: bool, + ) -> None: + self._sp_group = sp_group + self._ulysses_pg = sp_group.ulysses_group + self._scatter_idx = scatter_idx + self._gather_idx = gather_idx + self._use_sync = use_sync + + @property + def enabled(self) -> bool: + return True + + @property + def name(self) -> str: + return "ulysses" + + def pre_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None, + ): + joint_tensor_query = joint_tensor_key = joint_tensor_value = None + joint_strategy = "front" + joint_len = 0 + + if attn_metadata is not None: + joint_tensor_query = attn_metadata.joint_query + joint_tensor_key = attn_metadata.joint_key + joint_tensor_value = attn_metadata.joint_value + joint_strategy = attn_metadata.joint_strategy + + is_joint = False + if joint_tensor_query is not None and joint_tensor_key is not None and joint_tensor_value is not None: + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError( + f"joint_strategy: {joint_strategy} not supported." + f" supported joint strategy: {supported_joint_strategy}" + ) + + # Slice joint_query for this Ulysses rank + # joint_query is (B, S, H, D). We split H (dim 2). + ulysses_world_size = self._sp_group.ulysses_world_size + ulysses_rank = self._sp_group.ulysses_rank + attn_heads_per_ulysses_rank = joint_tensor_query.shape[-2] // ulysses_world_size + + # Note: We use the same heads for Q/K/V + joint_tensor_query = joint_tensor_query[ + ..., + attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1), + :, + ] + + joint_len = joint_tensor_query.shape[1] + + is_joint = True + elif joint_tensor_query is None and joint_tensor_key is None and joint_tensor_value is None: + pass + else: + raise ValueError("joint_query, joint_key, and joint_value should be None or not None simultaneously.") + + if is_joint: + # Slice joint key/value heads for this ulysses rank. + # Using same slicing logic as query + attn_heads_per_ulysses_rank_kv = joint_tensor_key.shape[-2] // ulysses_world_size + + joint_tensor_key = joint_tensor_key[ + ..., + attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1), + :, + ] + joint_tensor_value = joint_tensor_value[ + ..., + attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1), + :, + ] + + # Update metadata with sliced tensors so Ring attention can use them if needed + if attn_metadata is not None: + attn_metadata.joint_key = joint_tensor_key + attn_metadata.joint_value = joint_tensor_value + + # (bs, seq_len/P, head_cnt, head_size) -> (bs, seq_len, head_cnt/P, head_size) + query = SeqAllToAll4D.apply(self._ulysses_pg, query, self._scatter_idx, self._gather_idx, self._use_sync) + key = SeqAllToAll4D.apply(self._ulysses_pg, key, self._scatter_idx, self._gather_idx, self._use_sync) + value = SeqAllToAll4D.apply(self._ulysses_pg, value, self._scatter_idx, self._gather_idx, self._use_sync) + + if is_joint: + # Concatenate joint query AFTER AllToAll + # Image query is now (B, S, H/P, D). Joint query is (B, S_txt, H/P, D). + # This is dimensionally consistent. 
+ if joint_strategy == "rear": + query = torch.cat([query, joint_tensor_query], dim=1) + else: + query = torch.cat([joint_tensor_query, query], dim=1) + + # Check if Ring Attention is also active (Hybrid mode) + # If Ring is active, we should NOT concatenate joint_key/value to k/v here. + # Instead, they should remain in attn_metadata and be passed to the Ring kernel. + use_ring = self._sp_group.ring_world_size > 1 + + if is_joint and not use_ring: + # Concatenate joint key/value after all-to-all ONLY for pure Ulysses (Local Attention). + if joint_strategy == "front": + key = torch.cat([joint_tensor_key, key], dim=1) + value = torch.cat([joint_tensor_value, value], dim=1) + else: # "rear" + key = torch.cat([key, joint_tensor_key], dim=1) + value = torch.cat([value, joint_tensor_value], dim=1) + + ctx = _UlyssesCtx( + name=self.name, + ulysses_pg=self._ulysses_pg, + scatter_idx=self._scatter_idx, + gather_idx=self._gather_idx, + use_sync=self._use_sync, + joint_len=joint_len, + joint_strategy=joint_strategy, + ) + + if attn_metadata is not None: + if is_joint: + if attn_metadata.joint_attn_mask is None and attn_metadata.attn_mask is None: + attn_metadata.attn_mask = None + else: + if attn_metadata.attn_mask is None: + attn_metadata.attn_mask = torch.ones( + [query.shape[0], query.shape[1] - attn_metadata.joint_attn_mask.shape[1]], + dtype=torch.bool, + device=query.device, + ) + elif attn_metadata.joint_attn_mask is None: + attn_metadata.joint_attn_mask = torch.ones( + [query.shape[0], query.shape[1] - attn_metadata.attn_mask.shape[1]], + dtype=torch.bool, + device=query.device, + ) + attn_metadata.attn_mask = ( + torch.cat([attn_metadata.joint_attn_mask, attn_metadata.attn_mask], dim=1) + if joint_strategy == "front" + else torch.cat([attn_metadata.attn_mask, attn_metadata.joint_attn_mask], dim=1) + ) + + if attn_metadata.attn_mask is not None: + # the final attn_mask is ready, the length should be aligedn with query length + assert attn_metadata.attn_mask.shape[1] == query.shape[1], ( + f"attn_mask length: {attn_metadata.attn_mask.shape[1]} != query length: {query.shape[1]}" + ) + attn_metadata.attn_mask = attn_metadata.attn_mask.bool().contiguous() + return query, key, value, attn_metadata, ctx + + def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: + assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}" + + # If we have joint tensors (Text), they were Head-Sliced. + # The main sequence (Image) was Sequence-Sliced. + # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front'). + + if ctx.joint_len > 0: + joint_len = ctx.joint_len + + if ctx.joint_strategy == "front": + output_joint = attn_output[:, :joint_len] + output_img = attn_output[:, joint_len:] + else: + output_img = attn_output[:, :-joint_len] + output_joint = attn_output[:, -joint_len:] + + # 1. Process Image part: Standard Ulysses Reverse (AllToAll) + # (bs, seq_len, head_cnt/P, head_size) -> (bs, seq_len/P, head_cnt, head_size) + # SeqAllToAll4D handles: Scatter gather_idx, Gather scatter_idx. + # Forward: Scatter 2 (H), Gather 1 (S). + # Reverse: Scatter 1 (S), Gather 2 (H). + output_img = SeqAllToAll4D.apply(ctx.ulysses_pg, output_img, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync) + + # 2. Process Joint part: AllGather on Heads + # Input: (B, JointLen, H/P, D). Output: (B, JointLen, H, D). + # AllGather along dim 2. 
+ # Ensure tensor is contiguous for all_gather (slicing may create non-contiguous views) + output_joint = output_joint.contiguous() + gathered_joint = [torch.zeros_like(output_joint) for _ in range(dist.get_world_size(ctx.ulysses_pg))] + dist.all_gather(gathered_joint, output_joint, group=ctx.ulysses_pg) + output_joint = torch.cat(gathered_joint, dim=2) + + # 3. Recombine + if ctx.joint_strategy == "front": + return torch.cat([output_joint, output_img], dim=1) + else: + return torch.cat([output_img, output_joint], dim=1) + + # Standard Ulysses Reverse + return SeqAllToAll4D.apply(ctx.ulysses_pg, attn_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync) diff --git a/vllm_omni/diffusion/attention/selector.py b/vllm_omni/diffusion/attention/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..920396e66b97b1db8fdc3aaa51e427faa05a1044 --- /dev/null +++ b/vllm_omni/diffusion/attention/selector.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Diffusion attention backend selector. + +This module provides the interface for selecting diffusion attention backends. +The actual backend selection logic is delegated to the platform layer +(vllm_omni.platforms), similar to how vLLM handles attention backend selection. + +Usage: + from vllm_omni.diffusion.attention.selector import get_attn_backend + + # Get the appropriate backend for current platform + backend_cls = get_attn_backend(head_size=64) + + # Or override via environment variable + # export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN +""" + +import importlib +import os +from functools import cache + +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, +) + +logger = init_logger(__name__) + + +def _load_backend_cls(cls_path: str) -> type[AttentionBackend]: + """Load a backend class from its fully qualified path. + + Args: + cls_path: Fully qualified class path (e.g., + "vllm_omni.diffusion.attention.backends.sdpa.SDPABackend") + + Returns: + The loaded backend class + """ + module_path, class_name = cls_path.rsplit(".", 1) + try: + module = importlib.import_module(module_path) + backend_class = getattr(module, class_name) + return backend_class + except ImportError as e: + raise ImportError(f"Failed to import module {module_path}: {e}") + except AttributeError as e: + raise AttributeError(f"Class {class_name} not found in module: {e}") + + +@cache +def get_attn_backend(head_size: int) -> type[AttentionBackend]: + """ + Get attention backend for diffusion models. + + The backend selection is delegated to the current platform + (vllm_omni.platforms.current_omni_platform), which selects the + appropriate backend based on: + 1. User override via DIFFUSION_ATTENTION_BACKEND environment variable + 2. Platform-specific defaults and capabilities + + This is similar to how vLLM's get_attn_backend_cls works, where the + platform layer decides which backend to use based on hardware capabilities. 
+ + Args: + head_size: Head size for attention computation (may affect backend selection) + + Returns: + The selected attention backend class + """ + from vllm_omni.platforms import current_omni_platform + + # Check environment variable for user override + selected_backend = os.environ.get("DIFFUSION_ATTENTION_BACKEND") + + # Delegate to platform for backend selection + backend_cls_path = current_omni_platform.get_diffusion_attn_backend_cls( + selected_backend=selected_backend, + head_size=head_size, + ) + + return _load_backend_cls(backend_cls_path) diff --git a/vllm_omni/diffusion/cache/__init__.py b/vllm_omni/diffusion/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5968f612a460d4805d0b3707c2c9b6bd40657b6 --- /dev/null +++ b/vllm_omni/diffusion/cache/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Cache module for diffusion model inference acceleration. + +This module provides a unified cache backend system for different caching strategies: +- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching +- cache-dit: DBCache, SCM, and TaylorSeer caching strategies + +Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig. +""" + +from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.teacache import ( + CacheContext, + TeaCacheConfig, + apply_teacache_hook, +) +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend + +__all__ = [ + "CacheBackend", + "TeaCacheConfig", + "CacheContext", + "TeaCacheBackend", + "apply_teacache_hook", +] diff --git a/vllm_omni/diffusion/cache/base.py b/vllm_omni/diffusion/cache/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d7807622497def3d5ee059ec056acde98d32883e --- /dev/null +++ b/vllm_omni/diffusion/cache/base.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Base cache backend interface for diffusion models. + +This module defines the abstract base class that all cache backends must implement. +Cache backends provide a unified interface for applying different caching strategies +to transformer models. + +Main cache backend implementations: +1. CacheDiTBackend: Implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using + the cache-dit library. Inherits from CacheBackend. Used via cache_backend="cache_dit". +2. TeaCacheBackend: Hook-based backend for TeaCache acceleration. Inherits from + CacheBackend. Used via cache_backend="tea_cache". + +All backends implement the same interface: +- enable(pipeline): Enable cache on the pipeline +- refresh(pipeline, num_inference_steps, verbose): Refresh cache state +- is_enabled(): Check if cache is enabled +""" + +from abc import ABC, abstractmethod +from typing import Any + +import torch.nn as nn + +from vllm_omni.diffusion.data import DiffusionCacheConfig + + +class CacheBackend(ABC): + """ + Abstract base class for cache backends. + + All cache backend implementations (CacheDiTBackend, TeaCacheBackend, etc.) inherit + from this base class and implement the enable() and refresh() methods to manage + cache lifecycle. + + Cache backends apply caching strategies to transformer models to accelerate + inference. 
Different backends use different underlying mechanisms (e.g., cache-dit + library for CacheDiTBackend, hooks for TeaCacheBackend), but all share the same + unified interface. + + Attributes: + config: DiffusionCacheConfig instance containing cache-specific configuration parameters + enabled: Boolean flag indicating whether cache is enabled (set to True after enable() is called) + """ + + def __init__(self, config: DiffusionCacheConfig): + """ + Initialize cache backend with configuration. + + Args: + config: DiffusionCacheConfig instance with cache-specific parameters + """ + self.config = config + self.enabled = False + + @abstractmethod + def enable(self, pipeline: Any) -> None: + """ + Enable cache on the pipeline. + + This method applies the caching strategy to the transformer(s) in the pipeline. + The specific implementation depends on the backend (e.g., hooks for TeaCacheBackend, + cache-dit library for CacheDiTBackend). Called once during pipeline initialization. + + Args: + pipeline: Diffusion pipeline instance. The backend can extract: + - transformer: via pipeline.transformer + - model_type: via pipeline.__class__.__name__ + """ + raise NotImplementedError("Subclasses must implement enable()") + + @abstractmethod + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """ + Refresh cache state for new generation. + + This method should clear any cached values and reset counters/accumulators. + Called at the start of each generation to ensure clean state. + + Args: + pipeline: Diffusion pipeline instance. The backend can extract: + - transformer: via pipeline.transformer + num_inference_steps: Number of inference steps for the current generation. + May be used for cache context updates. + verbose: Whether to log refresh operations (default: True) + """ + raise NotImplementedError("Subclasses must implement refresh()") + + def is_enabled(self) -> bool: + """ + Check if cache is enabled on this backend. + + Returns: + True if cache is enabled, False otherwise. + """ + return self.enabled + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(config={self.config})" + + +class CachedTransformer(nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.do_true_cfg = False + + def __init_subclass__(cls, enable_separate_cfg: bool = True, **kwargs): + cls.enable_separate_cfg = enable_separate_cfg + super().__init_subclass__(**kwargs) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c922121e9aef6ca8be1b4882f879882ef74c16c2 --- /dev/null +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -0,0 +1,923 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +cache-dit integration backend for vllm-omni. + +This module provides a CacheDiTBackend class to enable cache-dit acceleration on diffusion +pipelines in vllm-omni, supporting both single and dual-transformer architectures. 
+""" + +import functools +from collections.abc import Callable +from contextlib import ExitStack +from typing import Any, Optional + +import cache_dit +import torch +from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig +from cache_dit.caching.block_adapters import FakeDiffusionPipeline +from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter +from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2 +from cache_dit.caching.cache_contexts import BasicCacheConfig +from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager +from vllm.logger import init_logger + +from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.data import DiffusionCacheConfig, OmniDiffusionConfig + +logger = init_logger(__name__) + + +# Small helper to centralize cache-dit summaries. +def cache_summary(pipeline: Any, details: bool = True) -> None: + cache_dit.summary(pipeline.transformer, details=details) + if hasattr(pipeline, "transformer_2"): + cache_dit.summary(pipeline.transformer_2, details=details) + + +# Registry of custom cache-dit enablers for specific models +# Maps pipeline names to their cache-dit enablement functions +# Models in this registry require custom handling (e.g., dual-transformer architectures) +# Will be populated after function definitions +CUSTOM_DIT_ENABLERS: dict[str, Callable] = {} + + +def _build_db_cache_config(cache_config: Any) -> DBCacheConfig: + """Build DBCacheConfig with optional SCM (Step Computation Masking) support. + + Args: + cache_config: DiffusionCacheConfig instance. + + Returns: + DBCacheConfig instance with SCM support if configured. + """ + + return DBCacheConfig( + # we will refresh the context when gets num_inference_steps in the first inference request + num_inference_steps=None, + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + ) + + +def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Wan2.2 dual-transformer architecture. + + Wan2.2 uses two transformers (transformer and transformer_2) that need + to be enabled together using BlockAdapter. + + Args: + pipeline: The Wan2.2 pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. 
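+
+    Example (illustrative sketch; assumes an already loaded Wan2.2 pipeline "pipe"
+    and a populated DiffusionCacheConfig "cache_config"):
+        >>> refresh = enable_cache_for_wan22(pipe, cache_config)
+        >>> refresh(pipe, num_inference_steps=40)  # re-splits steps across both transformers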
+ """ + + cache_dit.enable_cache( + BlockAdapter( + transformer=[ + pipeline.transformer, + pipeline.transformer_2, + ], + blocks=[ + pipeline.transformer.blocks, + pipeline.transformer_2.blocks, + ], + forward_pattern=[ + ForwardPattern.Pattern_2, + ForwardPattern.Pattern_2, + ], + params_modifiers=[ + # high-noise transformer only have 30% steps + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + ), + ), + ParamsModifier( + cache_config=DBCacheConfig().reset( + max_warmup_steps=2, + max_cached_steps=20, + ), + ), + ], + has_separate_cfg=True, + ), + cache_config=DBCacheConfig( + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + num_inference_steps=None, + ), + ) + + # from https://github.com/vipshop/cache-dit/pull/542 + def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]: + """Split inference steps into high-noise and low-noise steps for Wan2.2. + + This is an internal helper function specific to Wan2.2's dual-transformer + architecture that uses boundary_ratio to determine the split point. + + Args: + num_inference_steps: Total number of inference steps. + + Returns: + A tuple of (num_high_noise_steps, num_low_noise_steps). + """ + if pipeline.boundary_ratio is not None: + boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + # Set timesteps to calculate the split + device = next(pipeline.transformer.parameters()).device + pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = pipeline.scheduler.timesteps + num_high_noise_steps = 0 # high-noise steps for transformer + for t in timesteps: + if boundary_timestep is None or t >= boundary_timestep: + num_high_noise_steps += 1 + # low-noise steps for transformer_2 + num_low_noise_steps = num_inference_steps - num_high_noise_steps + return num_high_noise_steps, num_low_noise_steps + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for both transformers with new num_inference_steps. + + Args: + pipeline: The Wan2.2 pipeline instance. + num_inference_steps: New number of inference steps. 
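+            verbose: Whether to log refresh operations (default: True).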
+ """ + + num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps) + # Refresh context for high-noise transformer + if cache_config.scm_steps_mask_policy is None: + # cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose) + cache_dit.refresh_context( + pipeline.transformer, + num_inference_steps=num_high_noise_steps, + verbose=verbose, + ) + cache_dit.refresh_context( + pipeline.transformer_2, + num_inference_steps=num_low_noise_steps, + verbose=verbose, + ) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_high_noise_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_high_noise_steps + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + cache_dit.refresh_context( + pipeline.transformer_2, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_low_noise_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_low_noise_steps + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for LongCatImage pipeline. + + Args: + pipeline: The LongCatImage pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + """ + # Build DBCacheConfig for transformer + db_cache_config = _build_db_cache_config(cache_config) + + calibrator = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Build ParamsModifier for transformer + modifier = ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator, + ) + + logger.info( + f"Enabling cache-dit on LongCatImage transformer with BlockAdapter: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit using BlockAdapter for transformer + cache_dit.enable_cache( + ( + BlockAdapter( + transformer=pipeline.transformer, + blocks=[ + pipeline.transformer.transformer_blocks, + pipeline.transformer.single_transformer_blocks, + ], + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1], + params_modifiers=[modifier], + ) + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The LongCatImage pipeline instance. + num_inference_steps: New number of inference steps. 
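+            verbose: Whether to log refresh operations (default: True).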
+ """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +def enable_cache_for_flux(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Flux dual-transformer architecture. + + Flux uses two transformers (transformer and transformer_2) that need + to be enabled together using BlockAdapter. + + Args: + pipeline: The Flux pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + raise NotImplementedError("cache-dit is not implemented for Flux pipeline.") + + +def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for StableDiffusion3Pipeline. + + Args: + pipeline: The StableDiffusion3 pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + """ + # Build DBCacheConfig for transformer + db_cache_config = _build_db_cache_config(cache_config) + + calibrator = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Build ParamsModifier for transformer + modifier = ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator, + ) + + logger.info( + f"Enabling cache-dit on StableDiffusion3 transformer with BlockAdapter: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit using BlockAdapter for transformer + cache_dit.enable_cache( + ( + BlockAdapter( + transformer=pipeline.transformer, + blocks=pipeline.transformer.transformer_blocks, + forward_pattern=ForwardPattern.Pattern_1, + params_modifiers=[modifier], + ) + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The LongCatImage pipeline instance. + num_inference_steps: New number of inference steps. + """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for regular single-transformer DiT models. + + Args: + pipeline: The diffusion pipeline instance. 
+ cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + # Build DBCacheConfig with optional SCM support + db_cache_config = _build_db_cache_config(cache_config) + + # Build calibrator config if TaylorSeer is enabled + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + logger.info( + f"Enabling cache-dit on transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit on the transformer + cache_dit.enable_cache( + pipeline.transformer, + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The diffusion pipeline instance. + num_inference_steps: New number of inference steps. + """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +class BagelCachedContextManager(CachedContextManager): + """ + Custom CachedContextManager for Bagel that safely handles NaiveCache objects + (mapped to encoder_hidden_states) by skipping tensor operations on them. + """ + + @torch.compiler.disable + def apply_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + prefix: str = "Bn", + encoder_prefix: str = "Bn_encoder", + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Allow Bn and Fn prefix to be used for residual cache. 
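+        # "Bn" prefixes read from the back-block (Bn) buffer, anything else from the
+        # front-block (Fn) buffer; the buffer holds the previously stored states to reuse.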
+ if "Bn" in prefix: + hidden_states_prev = self.get_Bn_buffer(prefix) + else: + hidden_states_prev = self.get_Fn_buffer(prefix) + + assert hidden_states_prev is not None, f"{prefix}_buffer must be set before" + + if self.is_cache_residual(): + hidden_states = hidden_states_prev + hidden_states + else: + # If cache is not residual, we use the hidden states directly + hidden_states = hidden_states_prev + + hidden_states = hidden_states.contiguous() + + if encoder_hidden_states is not None: + if "Bn" in encoder_prefix: + encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix) + else: + encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix) + + if encoder_hidden_states_prev is not None: + if self.is_encoder_cache_residual(): + # FIX: Check if encoder_hidden_states is a tensor before adding + if isinstance(encoder_hidden_states, torch.Tensor) and isinstance( + encoder_hidden_states_prev, torch.Tensor + ): + encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states + else: + # If encoder cache is not residual, we use the encoder hidden states directly + encoder_hidden_states = encoder_hidden_states_prev + + # FIX: Check if encoder_hidden_states is a tensor before calling contiguous + if isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states.contiguous() + + return hidden_states, encoder_hidden_states + + +class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2): + """ + Custom CachedBlocks for Bagel that safely handles NaiveCache objects + by adding isinstance checks in call_Mn_blocks and compute_or_prune. + """ + + def call_Mn_blocks( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + for block in self._Mn_blocks(): + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + + # compute hidden_states residual + hidden_states = hidden_states.contiguous() + + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + return ( + hidden_states, + encoder_hidden_states, + hidden_states_residual, + encoder_hidden_states_residual, + ) + + def compute_or_prune( + self, + block_id: int, # Block index in the transformer blocks + # Below are the inputs to the block + block, # The transformer block to be executed + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case. + # Copy-pasted from original but adding checks. 
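+        # Flow: if the step can be pruned, previously cached residuals are applied
+        # instead of running the block; otherwise the block is computed and its
+        # fresh residuals are written back into the context manager's buffers.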
+ + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + + can_use_prune = self._maybe_prune( + block_id, + hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + + torch._dynamo.graph_break() + if can_use_prune: + self.context_manager.add_pruned_step() + hidden_states, encoder_hidden_states = self.context_manager.apply_prune( + hidden_states, + encoder_hidden_states, + prefix=( + f"{self.cache_prefix}_{block_id}_Bn_residual" + if self.context_manager.is_cache_residual() + else f"{self.cache_prefix}_Bn_hidden_states" + ), + encoder_prefix=( + f"{self.cache_prefix}_{block_id}_Bn_encoder_residual" + if self.context_manager.is_encoder_cache_residual() + else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states" + ), + ) + torch._dynamo.graph_break() + else: + # Normal steps: Compute the block and cache the residuals. + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + if not self._skip_prune(block_id): + hidden_states = hidden_states.contiguous() + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + self.context_manager.set_Fn_buffer( + original_hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + if self.context_manager.is_cache_residual(): + self.context_manager.set_Bn_buffer( + hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_residual", + ) + else: + self.context_manager.set_Bn_buffer( + hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states", + ) + if encoder_hidden_states_residual is not None: + if self.context_manager.is_encoder_cache_residual(): + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual", + ) + else: + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states", + ) + torch._dynamo.graph_break() + + return hidden_states, encoder_hidden_states + + +class BagelCachedAdapter(CachedAdapter): + """ + Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks. + """ + + @classmethod + def create_context( + cls, + block_adapter: BlockAdapter, + **context_kwargs, + ) -> tuple[list[str], list[dict[str, Any]]]: + # Override to use BagelCachedContextManager + + BlockAdapter.assert_normalized(block_adapter) + + if BlockAdapter.is_cached(block_adapter.pipe): + return block_adapter.pipe + + # Check context_kwargs + context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs) + + # Each Pipeline should have it's own context manager instance. + cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None) + assert cache_config is not None, "cache_config can not be None." 
+ + # Apply cache on pipeline: wrap cache context + pipe_cls_name = block_adapter.pipe.__class__.__name__ + + # USE CUSTOM CONTEXT MANAGER + context_manager = BagelCachedContextManager( + name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}", + persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline), + ) + + flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs) + + block_adapter.pipe._context_manager = context_manager # instance level + + if not context_manager.persistent_context: + original_call = block_adapter.pipe.__class__.__call__ + + @functools.wraps(original_call) + def new_call(self, *args, **kwargs): + with ExitStack() as stack: + # cache context will be reset for each pipe inference + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + stack.enter_context( + context_manager.enter_context( + context_manager.reset_context( + context_name, + **context_kwargs, + ), + ) + ) + outputs = original_call(self, *args, **kwargs) + cls.apply_stats_hooks(block_adapter) + return outputs + + block_adapter.pipe.__class__.__call__ = new_call + block_adapter.pipe.__class__._original_call = original_call + + else: + # Init persistent cache context for transformer + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + context_manager.reset_context( + context_name, + **context_kwargs, + ) + + block_adapter.pipe.__class__._is_cached = True + + cls.apply_params_hooks(block_adapter, contexts_kwargs) + + return flatten_contexts, contexts_kwargs + + @classmethod + def collect_unified_blocks( + cls, + block_adapter: BlockAdapter, + contexts_kwargs: list[dict], + ) -> list[dict[str, torch.nn.ModuleList]]: + # Override to use BagelCachedBlocks + + BlockAdapter.assert_normalized(block_adapter) + + total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = [] + assert hasattr(block_adapter.pipe, "_context_manager") + # Skipping isinstance check for ContextManager._supported_managers to avoid import issues + + for i in range(len(block_adapter.transformer)): + unified_blocks_bind_context = {} + for j in range(len(block_adapter.blocks[i])): + cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"] + + # Directly instantiate BagelCachedBlocks + unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList( + [ + BagelCachedBlocks( + # 0. Transformer blocks configuration + block_adapter.blocks[i][j], + transformer=block_adapter.transformer[i], + forward_pattern=block_adapter.forward_pattern[i][j], + check_forward_pattern=block_adapter.check_forward_pattern, + check_num_outputs=block_adapter.check_num_outputs, + # 1. Cache/Prune context configuration + cache_prefix=block_adapter.blocks_name[i][j], + cache_context=block_adapter.unique_blocks_name[i][j], + context_manager=block_adapter.pipe._context_manager, + cache_type=cache_config.cache_type, + ) + ] + ) + + total_cached_blocks.append(unified_blocks_bind_context) + + return total_cached_blocks + + +def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Bagel model (via OmniDiffusion pipeline). + + Args: + pipeline: The OmniDiffusion pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. 
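+
+    Example (illustrative sketch; assumes a BagelPipeline instance "pipe"):
+        >>> refresh = enable_cache_for_bagel(pipe, cache_config)
+        >>> refresh(pipe, num_inference_steps=50)  # refreshes the wrapped language_model.model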
+ """ + # Build DBCacheConfig + db_cache_config = _build_db_cache_config(cache_config) + + # Build calibrator config if TaylorSeer is enabled + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel + # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM + # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel + transformer = pipeline.language_model.model + + logger.info( + f"Enabling cache-dit on Bagel transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit on the transformer + # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output + # Custom adapter for Bagel to handle NaiveCache correctly + # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed + BagelCachedAdapter.apply( + BlockAdapter( + transformer=transformer, + blocks=transformer.layers, + forward_pattern=ForwardPattern.Pattern_0, + ), + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + transformer = pipeline.language_model.model + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +# Register custom cache-dit enablers after function definitions +CUSTOM_DIT_ENABLERS.update( + { + "Wan22Pipeline": enable_cache_for_wan22, + "Wan22I2VPipeline": enable_cache_for_wan22, + "Wan22TI2VPipeline": enable_cache_for_wan22, + "FluxPipeline": enable_cache_for_flux, + "LongCatImagePipeline": enable_cache_for_longcat_image, + "LongCatImageEditPipeline": enable_cache_for_longcat_image, + "StableDiffusion3Pipeline": enable_cache_for_sd3, + "BagelPipeline": enable_cache_for_bagel, + } +) + + +class CacheDiTBackend(CacheBackend): + """Backend class for cache-dit acceleration on diffusion pipelines. + + This class implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using + the cache-dit library. It inherits from CacheBackend and provides a unified + interface for managing cache-dit acceleration on diffusion models. + + Attributes: + config: Cache configuration (DiffusionCacheConfig instance), inherited from CacheBackend. + enabled: Whether cache-dit is enabled on this pipeline, inherited from CacheBackend. + _refresh_func: Internal refresh function for updating cache context. + _last_num_inference_steps: Last num_inference_steps used for refresh optimization. + """ + + def __init__(self, cache_config: Any = None): + """Initialize the cache-dit backend. + + Args: + cache_config: Cache configuration (DiffusionCacheConfig instance, dict, or None). + If None or empty, uses default DiffusionCacheConfig(). 
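+
+        Example (illustrative sketch; assumes a loaded diffusion pipeline and
+        field names from DiffusionCacheConfig):
+            >>> backend = CacheDiTBackend({"Fn_compute_blocks": 8, "max_cached_steps": 20})
+            >>> backend.enable(pipeline)
+            >>> backend.refresh(pipeline, num_inference_steps=50)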
+ """ + # Use default config if cache_config is not provided or is empty + if cache_config is None: + config = DiffusionCacheConfig() + elif isinstance(cache_config, dict): + # Convert dict to DiffusionCacheConfig, using defaults for missing keys + config = DiffusionCacheConfig.from_dict(cache_config) + else: + config = cache_config + + # Initialize base class with normalized config + super().__init__(config) + + # Cache-dit specific attributes + self._refresh_func: Callable[[Any, int, bool], None] | None = None + self._last_num_inference_steps: int | None = None + + def enable(self, pipeline: Any) -> None: + """Enable cache-dit on the pipeline if configured. + + This method applies cache-dit acceleration to the appropriate transformer(s) + in the pipeline. It handles both single-transformer and dual-transformer + architectures (e.g., Wan2.2). + + Args: + pipeline: The diffusion pipeline instance. + """ + + # Extract pipeline name from pipeline + pipeline_name = pipeline.__class__.__name__ + # Check if this model has a custom cache-dit enabler + if pipeline_name in CUSTOM_DIT_ENABLERS: + logger.info(f"Using custom cache-dit enabler for model: {pipeline_name}") + self._refresh_func = CUSTOM_DIT_ENABLERS[pipeline_name](pipeline, self.config) + else: + # For regular single-transformer models + self._refresh_func = enable_cache_for_dit(pipeline, self.config) + + self.enabled = True + logger.info(f"Cache-dit enabled successfully on {pipeline_name}") + + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context with new num_inference_steps. + + This method updates the cache context when num_inference_steps changes + during inference. For dual-transformer models (e.g., Wan2.2), it automatically + splits the steps based on boundary_ratio. + + Args: + pipeline: The diffusion pipeline instance. + num_inference_steps: New number of inference steps. + verbose: Whether to log refresh operations. + """ + if not self.enabled or self._refresh_func is None: + logger.warning("Cache-dit is not enabled. Cannot refresh cache context.") + return + + # Only refresh if num_inference_steps has changed + if self._last_num_inference_steps is None or num_inference_steps != self._last_num_inference_steps: + if verbose: + logger.info(f"Refreshing cache context for transformer with num_inference_steps: {num_inference_steps}") + self._refresh_func(pipeline, num_inference_steps, verbose) + self._last_num_inference_steps = num_inference_steps + + def is_enabled(self) -> bool: + """Check if cache-dit is enabled on this pipeline. + + Returns: + True if cache-dit is enabled, False otherwise. + """ + return self.enabled + + +def may_enable_cache_dit(pipeline: Any, od_config: OmniDiffusionConfig) -> Optional["CacheDiTBackend"]: + """Enable cache-dit on the pipeline if configured (convenience function). + + This is a convenience function that creates and enables a CacheDiTBackend. + For new code, consider using CacheDiTBackend directly. + + Args: + pipeline: The diffusion pipeline instance. + od_config: OmniDiffusionConfig with cache configuration. + + Returns: + A CacheDiTBackend instance if cache-dit is enabled, None otherwise. 
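+
+    Example (illustrative sketch; assumes od_config.cache_backend == "cache-dit"
+    and a non-empty od_config.cache_config):
+        >>> backend = may_enable_cache_dit(pipeline, od_config)
+        >>> if backend is not None:
+        ...     backend.refresh(pipeline, num_inference_steps=30)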
+ """ + if od_config.cache_backend != "cache-dit" or not od_config.cache_config: + return None + + backend = CacheDiTBackend(od_config.cache_config) + backend.enable(pipeline) + return backend if backend.is_enabled() else None diff --git a/vllm_omni/diffusion/cache/selector.py b/vllm_omni/diffusion/cache/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..7c09bf664755c7e9d860e411c3961c56f217111d --- /dev/null +++ b/vllm_omni/diffusion/cache/selector.py @@ -0,0 +1,38 @@ +from typing import Any + +from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend +from vllm_omni.diffusion.data import DiffusionCacheConfig + + +def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBackend | None: + """Get cache backend instance based on cache_backend string. + + This is a selector function that routes to the appropriate backend implementation. + - cache_dit: Uses CacheDiTBackend with enable()/refresh() interface + - tea_cache: Uses TeaCacheBackend with enable()/refresh() interface + + Args: + cache_backend: Cache backend name ("cache_dit", "tea_cache", or None). + cache_config: Cache configuration (dict or DiffusionCacheConfig instance). + + Returns: + Cache backend instance (CacheDiTBackend or TeaCacheBackend) if cache_backend is set, + None otherwise. + + Raises: + ValueError: If cache_backend is unsupported. + """ + if cache_backend is None or cache_backend == "none": + return None + + if isinstance(cache_config, dict): + cache_config = DiffusionCacheConfig.from_dict(cache_config) + + if cache_backend == "cache_dit": + return CacheDiTBackend(cache_config) + elif cache_backend == "tea_cache": + return TeaCacheBackend(cache_config) + else: + raise ValueError(f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache'") diff --git a/vllm_omni/diffusion/cache/teacache/__init__.py b/vllm_omni/diffusion/cache/teacache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4ba5c6c9f3f00a2b70ec856b72c26d3f2952cd --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/__init__.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration. + +TeaCache speeds up diffusion inference by reusing transformer block computations +when consecutive timestep embeddings are similar. + +This implementation uses a hooks-based approach that requires zero changes to +model code. Model developers only need to add an extractor function to support +new models. 
+ +Usage: + from vllm_omni import Omni + + omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2} + ) + images = omni.generate("a cat") + + # Alternative: Using environment variable + # export DIFFUSION_CACHE_BACKEND=tea_cache +""" + +from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend +from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig +from vllm_omni.diffusion.cache.teacache.extractors import ( + CacheContext, + register_extractor, +) +from vllm_omni.diffusion.cache.teacache.hook import TeaCacheHook, apply_teacache_hook +from vllm_omni.diffusion.cache.teacache.state import TeaCacheState + +__all__ = [ + "TeaCacheBackend", + "TeaCacheConfig", + "TeaCacheState", + "TeaCacheHook", + "apply_teacache_hook", + "register_extractor", + "CacheContext", +] diff --git a/vllm_omni/diffusion/cache/teacache/backend.py b/vllm_omni/diffusion/cache/teacache/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..bf328d43d8b4b61f036b9532bec9058e38246ad4 --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/backend.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +TeaCache backend implementation. + +This module provides the TeaCache backend that implements the CacheBackend +interface using the hooks-based TeaCache system. +""" + +from typing import Any + +from vllm.logger import init_logger + +from vllm_omni.diffusion.cache.base import CacheBackend +from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig +from vllm_omni.diffusion.cache.teacache.hook import TeaCacheHook, apply_teacache_hook +from vllm_omni.diffusion.data import DiffusionCacheConfig + +logger = init_logger(__name__) + + +def enable_bagel_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None: + """ + Enable TeaCache for Bagel model. + """ + teacache_config = TeaCacheConfig( + transformer_type="Bagel", + rel_l1_thresh=config.rel_l1_thresh, + coefficients=config.coefficients, + ) + transformer = pipeline.bagel + original_forward_flow = transformer._forward_flow + + import types + + def forward_alias(self, *args, **kwargs): + return original_forward_flow(*args, **kwargs) + + transformer.forward = types.MethodType(forward_alias, transformer) + apply_teacache_hook(transformer, teacache_config) + transformer._forward_flow = transformer.forward + pipeline.transformer = transformer + + logger.info( + f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, " + f"transformer_class={teacache_config.transformer_type}" + ) + + +CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache} + + +class TeaCacheBackend(CacheBackend): + """ + TeaCache implementation using hooks. + + TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique + that speeds up diffusion inference by reusing transformer block computations + when consecutive timestep embeddings are similar. + + The backend applies TeaCache hooks to the transformer which intercept the + forward pass and implement the caching logic transparently. 
+ + Example: + >>> from vllm_omni.diffusion.data import DiffusionCacheConfig + >>> backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2)) + >>> backend.enable(pipeline) + >>> # Generate with cache enabled + >>> backend.refresh(pipeline, num_inference_steps=50) # Refresh before each generation + >>> # Access config attributes: backend.config.rel_l1_thresh + """ + + def enable(self, pipeline: Any) -> None: + """ + Enable TeaCache on transformer using hooks. + + This creates a TeaCacheConfig from the backend's DiffusionCacheConfig + and applies the TeaCache hook to the transformer. + + Args: + pipeline: Diffusion pipeline instance. Extracts transformer and transformer_type: + - transformer: pipeline.transformer + - transformer_type: pipeline.transformer.__class__.__name__ + """ + # Helper to get pipeline class name + pipeline_type = pipeline.__class__.__name__ + + # Check for pipeline-level custom enablers + if pipeline_type in CUSTOM_TEACACHE_ENABLERS: + logger.info(f"Using custom TeaCache enabler for model: {pipeline_type}") + CUSTOM_TEACACHE_ENABLERS[pipeline_type](pipeline, self.config) + else: + transformer = pipeline.transformer + transformer_type = transformer.__class__.__name__ + + # Create TeaCacheConfig from DiffusionCacheConfig with transformer_type + # Access parameters via attribute access: config.rel_l1_thresh + # rel_l1_thresh already has a default value of 0.2 in DiffusionCacheConfig + try: + teacache_config = TeaCacheConfig( + transformer_type=transformer_type, + rel_l1_thresh=self.config.rel_l1_thresh, + coefficients=self.config.coefficients, + ) + except Exception as e: + logger.error(f"Failed to create TeaCacheConfig: {e}") + raise ValueError( + f"Invalid TeaCache configuration: {e}. " + f"Expected keys: rel_l1_thresh, coefficients (optional). " + f"transformer_type is automatically extracted from pipeline.transformer.__class__.__name__." + ) + + # Apply hook to transformer + apply_teacache_hook(transformer, teacache_config) + + logger.info( + f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, " + f"transformer_class={teacache_config.transformer_type}" + ) + + # Mark as enabled + self.enabled = True + + def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """ + Refresh TeaCache state for new generation. + + Clears all cached residuals and resets counters/accumulators. + Should be called before each generation to ensure clean state. + + Args: + pipeline: Diffusion pipeline instance. Extracts transformer via pipeline.transformer. + num_inference_steps: Number of inference steps for the current generation. + Currently not used by TeaCache but accepted for interface consistency. 
+ verbose: Whether to log refresh operations (default: True) + """ + # Extract transformer from pipeline + transformer = pipeline.transformer + + if hasattr(transformer, "_hook_registry"): + hook = transformer._hook_registry.get_hook(TeaCacheHook._HOOK_NAME) + if hook is not None: + transformer._hook_registry.reset_hook(TeaCacheHook._HOOK_NAME) + if verbose: + logger.debug(f"TeaCache state refreshed (num_inference_steps={num_inference_steps})") + else: + if verbose: + logger.warning("TeaCache hook not found, nothing to refresh") + else: + if verbose: + logger.warning("Transformer has no hook registry, TeaCache may not be applied") diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..cdacf49dc4c672b3fbcf7304721c14de9d77cbde --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types +from typing import Any + +import numpy as np +import torch +from vllm.config import LoadConfig + +from vllm_omni.diffusion.cache.teacache.extractors import get_extractor +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.hooks import HookRegistry, ModelHook +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline + + +class DataCollectionHook(ModelHook): + """Hook to collect modulated inputs and model outputs for TeaCache coefficient estimation.""" + + _HOOK_NAME = "teacache_collector" + + def __init__(self, transformer_type: str): + super().__init__() + self.transformer_type = transformer_type + self.extractor_fn = None + self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = [] + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + self.extractor_fn = get_extractor(self.transformer_type) + return module + + def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: + ctx = self.extractor_fn(module, *args, **kwargs) + modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy() + + outputs = ctx.run_transformer_blocks() + ctx.hidden_states = outputs[0] + if len(outputs) > 1 and ctx.encoder_hidden_states is not None: + ctx.encoder_hidden_states = outputs[1] + + model_output_cpu = ctx.hidden_states.detach().cpu().numpy() + self.current_trajectory.append((modulated_input_cpu, model_output_cpu)) + + return ctx.postprocess(ctx.hidden_states) + + def start_collection(self): + self.current_trajectory = [] + + def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]: + return list(self.current_trajectory) + + +class BagelAdapter: + """Adapter for Bagel model.""" + + @staticmethod + def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline: + od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) + od_config.model_class_name = "BagelPipeline" + + pipeline = BagelPipeline(od_config=od_config) + loader = DiffusersPipelineLoader(LoadConfig()) + loader.load_weights(pipeline) + pipeline.to(device) + return pipeline + + @staticmethod + def get_transformer(pipeline: Any) -> tuple[Any, str]: + return pipeline.bagel, "Bagel" + + @staticmethod + def install_hook(transformer: Any, hook: DataCollectionHook) -> None: + original_forward_flow = 
transformer._forward_flow + + def forward_alias(self, *args, **kwargs): + return original_forward_flow(*args, **kwargs) + + transformer.forward = types.MethodType(forward_alias, transformer) + registry = HookRegistry.get_or_create(transformer) + registry.register_hook(hook._HOOK_NAME, hook) + transformer._forward_flow = transformer.forward + + +class DefaultAdapter: + """Default adapter for standard diffusers pipelines.""" + + @staticmethod + def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: + raise NotImplementedError("DefaultAdapter.load_pipeline not implemented") + + @staticmethod + def get_transformer(pipeline: Any) -> tuple[Any, str]: + return pipeline.transformer, pipeline.transformer.__class__.__name__ + + @staticmethod + def install_hook(transformer: Any, hook: DataCollectionHook) -> None: + registry = HookRegistry.get_or_create(transformer) + registry.register_hook(hook._HOOK_NAME, hook) + + +_MODEL_ADAPTERS: dict[str, type] = { + "Bagel": BagelAdapter, +} + +_EPSILON = 1e-6 + + +def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float: + """Calculate relative L1 distance (Eq. 4 from TeaCache paper).""" + diff = np.abs(tensor_current - tensor_next).sum() + norm = np.abs(tensor_current).sum() + _EPSILON + return diff / norm + + +def estimate_teacache_coefficients( + collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4 +) -> list[float]: + """Estimate polynomial coefficients for TeaCache using np.polyfit.""" + input_diffs, output_diffs = [], [] + + for sample in collected_data: + for t in range(len(sample) - 1): + feat_in_curr, feat_out_curr = sample[t] + feat_in_next, feat_out_next = sample[t + 1] + input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next)) + output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next)) + + x = np.array(input_diffs, dtype=np.float64) + y = np.array(output_diffs, dtype=np.float64) + + print("Data statistics:") + print(f" Count: {len(x)}") + print(f" Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}") + print(f" Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}") + + return np.polyfit(x, y, poly_order).tolist() + + +class TeaCacheCoefficientEstimator: + """Model-agnostic helper class to collect data and estimate TeaCache coefficients.""" + + def __init__( + self, + model_path: str, + model_type: str = "Bagel", + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + # Add validation here ⬇️ + if model_type not in _MODEL_ADAPTERS: + available_types = list(_MODEL_ADAPTERS.keys()) + raise ValueError( + f"Unsupported model_type: '{model_type}'. " + f"Available types: {available_types}. " + f"To add support for a new model, add an entry to _MODEL_ADAPTERS." 
+ ) + + adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter) + self.pipeline = adapter.load_pipeline(model_path, device, dtype) + self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline) + self.hook = DataCollectionHook(self.transformer_type) + self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = [] + adapter.install_hook(self.transformer, self.hook) + + def collect_from_prompt(self, prompt: str, **generate_kwargs): + self.hook.start_collection() + from vllm_omni.diffusion.request import OmniDiffusionRequest + + req = OmniDiffusionRequest( + prompt=prompt, + num_inference_steps=generate_kwargs.get("num_inference_steps", 20), + seed=generate_kwargs.get("seed", 42), + ) + self.pipeline.forward(req) + trajectory = self.hook.stop_collection() + if trajectory: + self.collected_data.append(trajectory) + + def estimate(self, poly_order: int = 4) -> list[float]: + """Estimate polynomial coefficients from collected data. + + Args: + poly_order: Order of polynomial fit (default: 4) + + Returns: + List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0] + + Raises: + RuntimeError: If no data has been collected + """ + if not self.collected_data: + raise RuntimeError( + "No data collected for coefficient estimation. " + "Call collect_from_prompt() at least once before calling estimate()." + ) + return estimate_teacache_coefficients(self.collected_data, poly_order) diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py new file mode 100644 index 0000000000000000000000000000000000000000..30b1745f47881ad5006eb4c70691039b8949adab --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +# Model-specific polynomial coefficients for rescaling L1 distances +# These coefficients account for model-specific characteristics in how embeddings change +# Source: TeaCache paper and ComfyUI-TeaCache empirical tuning +_MODEL_COEFFICIENTS = { + # FLUX transformer coefficients from TeaCache paper + "FluxTransformer2DModel": [ + 4.98651651e02, + -2.83781631e02, + 5.58554382e01, + -3.82021401e00, + 2.64230861e-01, + ], + # Qwen-Image transformer coefficients from ComfyUI-TeaCache + # Tuned specifically for Qwen's dual-stream transformer architecture + # Used for all Qwen-Image Family pipelines, in general + "QwenImageTransformer2DModel": [ + -4.50000000e02, + 2.80000000e02, + -4.50000000e01, + 3.20000000e00, + -2.00000000e-02, + ], + # Bagel transformer coefficients + # Using Qwen's coefficients as reasonable default given shared architecture + "Bagel": [1.33313129e06, -1.68644226e05, 7.95050740e03, -1.63747873e02, 1.26352397e00], + # Z-Image transformer coefficients + # Copied from Qwen-Image, need to be tuned specifically for Z-Image in future + "ZImageTransformer2DModel": [ + -4.50000000e02, + 2.80000000e02, + -4.50000000e01, + 3.20000000e00, + -2.00000000e-02, + ], +} + + +@dataclass +class TeaCacheConfig: + """ + Configuration for TeaCache applied to transformer models. + + TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up + diffusion model inference by reusing transformer block computations when consecutive + timestep embeddings are similar. + + Args: + rel_l1_thresh: Threshold for accumulated relative L1 distance. When below threshold, + cached residual is reused. 
Values in [0.1, 0.3] work best: + - 0.2: ~1.5x speedup with minimal quality loss + - 0.4: ~1.8x speedup with slight quality loss + - 0.6: ~2.0x speedup with noticeable quality loss + coefficients: Polynomial coefficients for rescaling L1 distance. If None, uses + model-specific defaults based on transformer_type. + transformer_type: Transformer class name (e.g., "QwenImageTransformer2DModel"). + Auto-detected from pipeline.transformer.__class__.__name__ in backend. + Defaults to "QwenImageTransformer2DModel". + """ + + rel_l1_thresh: float = 0.2 + coefficients: list[float] | None = None + transformer_type: str = "QwenImageTransformer2DModel" + + def __post_init__(self) -> None: + """Validate and set default coefficients.""" + if self.rel_l1_thresh <= 0: + raise ValueError(f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}") + + if self.coefficients is None: + # Use model-specific coefficients, explicitly check if the type exists or not + if self.transformer_type not in _MODEL_COEFFICIENTS: + raise KeyError( + f"Cannot find coefficients for {self.transformer_type}. " + f"Supported: {list(_MODEL_COEFFICIENTS.keys())}" + ) + self.coefficients = _MODEL_COEFFICIENTS[self.transformer_type] + + if len(self.coefficients) != 5: + raise ValueError(f"coefficients must contain exactly 5 elements, got {len(self.coefficients)}") diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..78029791916784c335bfaadc3645175c4b29af93 --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -0,0 +1,650 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Model-specific extractors for TeaCache. + +This module provides a registry of extractor functions that know how to extract +modulated inputs from different transformer architectures. Adding support for +a new model requires only adding a new extractor function to the registry. + +With Option B enhancement, extractors now return a CacheContext object containing +all model-specific information needed for generic caching, including preprocessing, +transformer execution, and postprocessing logic. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +from vllm_omni.diffusion.forward_context import get_forward_context + + +@dataclass +class CacheContext: + """ + Context object containing all model-specific information for caching. + + This allows the TeaCacheHook to remain completely generic - all model-specific + logic is encapsulated in the extractor that returns this context. + + Attributes: + modulated_input: Tensor used for cache decision (similarity comparison). + Must be a torch.Tensor extracted from the first transformer block, + typically after applying normalization and modulation. + + hidden_states: Current hidden states (will be modified by caching). + Must be a torch.Tensor representing the main image/latent states + after preprocessing but before transformer blocks. + + encoder_hidden_states: Optional encoder states (for dual-stream models). + Set to None for single-stream models (e.g., Flux). + For dual-stream models (e.g., Qwen), contains text encoder outputs. + + temb: Timestep embedding tensor. + Must be a torch.Tensor containing the timestep conditioning. + + run_transformer_blocks: Callable that executes model-specific transformer blocks. 
+ Signature: () -> tuple[torch.Tensor, ...] + + Returns: + tuple containing: + - [0]: processed hidden_states (required) + - [1]: processed encoder_hidden_states (optional, only for dual-stream) + + Example for single-stream: + def run_blocks(): + h = hidden_states + for block in module.transformer_blocks: + h = block(h, temb=temb) + return (h,) + + Example for dual-stream: + def run_blocks(): + h, e = hidden_states, encoder_hidden_states + for block in module.transformer_blocks: + e, h = block(h, e, temb=temb) + return (h, e) + + postprocess: Callable that does model-specific output postprocessing. + Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple] + + Takes the processed hidden_states and applies final transformations + (normalization, projection) to produce the model output. + + Example: + def postprocess(h): + h = module.norm_out(h, temb) + output = module.proj_out(h) + return Transformer2DModelOutput(sample=output) + + extra_states: Optional dict for additional model-specific state. + Use this for models that need to pass additional context beyond + the standard fields. + """ + + modulated_input: torch.Tensor + hidden_states: torch.Tensor + encoder_hidden_states: torch.Tensor | None + temb: torch.Tensor + run_transformer_blocks: Callable[[], tuple[torch.Tensor, ...]] + postprocess: Callable[[torch.Tensor], Any] + extra_states: dict[str, Any] | None = None + + def validate(self) -> None: + """ + Validate that the CacheContext contains valid data. + + Raises: + TypeError: If fields have wrong types + ValueError: If tensors have invalid properties + RuntimeError: If callables fail basic invocation tests + + This method should be called after creating a CacheContext to catch + common developer errors early with clear error messages. 
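+
+        Example (illustrative sketch; "model_inputs" is a placeholder for the
+        transformer's forward kwargs):
+            >>> ctx = extract_qwen_context(module, **model_inputs)
+            >>> ctx.validate()  # raises early on type, shape, or device mismatches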
+ """ + # Validate tensor fields + if not isinstance(self.modulated_input, torch.Tensor): + raise TypeError(f"modulated_input must be torch.Tensor, got {type(self.modulated_input)}") + + if not isinstance(self.hidden_states, torch.Tensor): + raise TypeError(f"hidden_states must be torch.Tensor, got {type(self.hidden_states)}") + + if self.encoder_hidden_states is not None and not isinstance(self.encoder_hidden_states, torch.Tensor): + raise TypeError( + f"encoder_hidden_states must be torch.Tensor or None, got {type(self.encoder_hidden_states)}" + ) + + if not isinstance(self.temb, torch.Tensor): + raise TypeError(f"temb must be torch.Tensor, got {type(self.temb)}") + + # Validate callables + if not callable(self.run_transformer_blocks): + raise TypeError(f"run_transformer_blocks must be callable, got {type(self.run_transformer_blocks)}") + + if not callable(self.postprocess): + raise TypeError(f"postprocess must be callable, got {type(self.postprocess)}") + + # Validate tensor shapes are compatible + if self.modulated_input.shape[0] != self.hidden_states.shape[0]: + raise ValueError( + f"Batch size mismatch: modulated_input has batch size " + f"{self.modulated_input.shape[0]}, but hidden_states has " + f"{self.hidden_states.shape[0]}" + ) + + # Validate devices match + if self.modulated_input.device != self.hidden_states.device: + raise ValueError( + f"Device mismatch: modulated_input on {self.modulated_input.device}, " + f"hidden_states on {self.hidden_states.device}" + ) + + +def extract_qwen_context( + module: nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + timestep: torch.Tensor | float | int, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + guidance: torch.Tensor | None = None, + additional_t_cond: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + **kwargs: Any, +) -> CacheContext: + """ + Extract cache context for QwenImageTransformer2DModel. + + This is the ONLY Qwen-specific code needed for TeaCache support. + It encapsulates preprocessing, modulated input extraction, transformer execution, + and postprocessing logic. 
+ + Args: + module: QwenImageTransformer2DModel instance + hidden_states: Input hidden states tensor + encoder_hidden_states: Text encoder outputs + encoder_hidden_states_mask: Mask for text encoder + timestep: Current diffusion timestep + img_shapes: Image shapes for position embedding + txt_seq_lens: Text sequence lengths + guidance: Optional guidance scale for CFG + additional_t_cond: Optional additional timestep conditioning + attention_kwargs: Additional attention arguments + **kwargs: Additional keyword arguments ignored by this extractor + + Returns: + CacheContext with all information needed for generic caching + """ + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError("Module must have transformer_blocks") + + # ============================================================================ + # PREPROCESSING (Qwen-specific) + # ============================================================================ + hidden_states = module.img_in(hidden_states) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) + encoder_hidden_states = module.txt_norm(encoder_hidden_states) + encoder_hidden_states = module.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + module.time_text_embed(timestep, hidden_states, additional_t_cond) + if guidance is None + else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) + ) + + image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + # ============================================================================ + # EXTRACT MODULATED INPUT (for cache decision) + # ============================================================================ + block = module.transformer_blocks[0] + img_mod_params = block.img_mod(temb) + img_mod1, _ = img_mod_params.chunk(2, dim=-1) + img_modulated, _ = block.img_norm1(hidden_states, img_mod1) + + # ============================================================================ + # DEFINE TRANSFORMER EXECUTION (Qwen-specific) + # ============================================================================ + def run_transformer_blocks(): + """Execute all Qwen transformer blocks.""" + h = hidden_states + e = encoder_hidden_states + encoder_mask = encoder_hidden_states_mask + hidden_states_mask = None # default + if module.parallel_config is not None and module.parallel_config.sequence_parallel_size > 1: + ctx = get_forward_context() + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + # Create mask for the full (padded) sequence + # valid positions = True, padding positions = False + batch_size = hidden_states.shape[0] + padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + hidden_states_mask = torch.ones( + batch_size, + padded_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, ctx.sp_original_seq_len :] = False + + # if mask is all true, set it to None + if hidden_states_mask is not None and hidden_states_mask.all(): + hidden_states_mask = None + if encoder_mask is not None and encoder_mask.all(): + encoder_mask = None + for block in module.transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + encoder_hidden_states_mask=encoder_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + 
hidden_states_mask=hidden_states_mask, + ) + return (h, e) + + # ============================================================================ + # DEFINE POSTPROCESSING (Qwen-specific) + # ============================================================================ + return_dict = kwargs.get("return_dict", True) + + def postprocess(h): + """Apply Qwen-specific output postprocessing.""" + h = module.norm_out(h, temb) + output = module.proj_out(h) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + # ============================================================================ + # RETURN CONTEXT + # ============================================================================ + return CacheContext( + modulated_input=img_modulated, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + + +def extract_bagel_context( + module: nn.Module, + x_t: torch.Tensor, + timestep: torch.Tensor | float | int, + packed_vae_token_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + key_values_lens: torch.IntTensor, + past_key_values: Any, + packed_key_value_indexes: torch.LongTensor, + **kwargs: Any, +) -> CacheContext: + """ + Extract cache context for Bagel model. + + Args: + module: Bagel instance + x_t: Latent image input + timestep: Current timestep + packed_vae_token_indexes: Indexes for VAE tokens in packed sequence + packed_vae_position_ids: Position IDs for VAE tokens + packed_text_ids: Text token IDs + packed_text_indexes: Indexes for text tokens in packed sequence + packed_indexes: Global indexes + packed_position_ids: Global position IDs + packed_seqlens: Sequence lengths + key_values_lens: KV cache lengths + past_key_values: KV cache + packed_key_value_indexes: KV cache indexes + **kwargs: Additional keyword arguments + + Returns: + CacheContext with all information needed for generic caching + """ + + # 1. Embed text + packed_text_embedding = module.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), module.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + # 2. Embed timestep + if not isinstance(timestep, torch.Tensor): + timestep = torch.tensor([timestep], device=x_t.device) + if timestep.dim() == 0: + timestep = timestep.unsqueeze(0) + + # 3. 
Embed image (x_t) + packed_pos_embed = module.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = module.time_embedder(timestep) + + x_t_emb = module.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed + if x_t_emb.dtype != packed_sequence.dtype: + x_t_emb = x_t_emb.to(packed_sequence.dtype) + + packed_sequence[packed_vae_token_indexes] = x_t_emb + + # Use the full packed sequence as modulated input to match hidden_states size + modulated_input = packed_sequence + + def run_transformer_blocks(): + extra_inputs = {} + if module.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes, + } + + output = module.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + return (output.packed_query_sequence,) + + def postprocess(h): + v_t = module.llm2vae(h) + v_t = v_t[packed_vae_token_indexes] + return v_t + + return CacheContext( + modulated_input=modulated_input, + hidden_states=packed_sequence, # Use full packed sequence + encoder_hidden_states=None, + temb=packed_timestep_embeds, # Approximate + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + + +def extract_zimage_context( + module: nn.Module, + x: list[torch.Tensor], + t: torch.Tensor, + cap_feats: list[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + **kwargs: Any, +) -> CacheContext: + """ + Extract cache context for ZImageTransformer2DModel. + + This is the ONLY Z-Image-specific code needed for TeaCache support. + It encapsulates preprocessing, modulated input extraction, transformer execution, + and postprocessing logic. 
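+    Note: the noise refiner and context refiner blocks run eagerly during
+    preprocessing; only the main module.layers stack is deferred behind
+    run_transformer_blocks, so cached steps still pay the refiner cost.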
+ + Args: + module: ZImageTransformer2DModel instance + x: List of image tensors per batch item + t: Timestep tensor + cap_feats: List of caption feature tensors per batch item + patch_size: Patch size for patchification (default: 2) + f_patch_size: Frame patch size (default: 1) + **kwargs: Additional keyword arguments ignored by this extractor + + Returns: + CacheContext with all information needed for generic caching + """ + from torch.nn.utils.rnn import pad_sequence + + if not hasattr(module, "layers") or len(module.layers) == 0: + raise ValueError("Module must have main transformer layers") + + bsz = len(x) + device = x[0].device + + # ============================================================================ + # PREPROCESSING (Z-Image specific) + # ============================================================================ + # Scale timestep and create timestep embedding + t_scaled = t * module.t_scale + adaln_input = module.t_embedder(t_scaled) + + # Patchify and embed inputs + ( + x_patches, + cap_feats_processed, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = module.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # Process image patches through embedder and noise refiner + x_item_seqlens = [len(_) for _ in x_patches] + x_max_item_seqlen = max(x_item_seqlens) + + x_embedded = torch.cat(x_patches, dim=0) + x_embedded = module.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_embedded) + + # Match adaln_input dtype to x_embedded + adaln_input = adaln_input.type_as(x_embedded) + + # Apply pad token + x_embedded[torch.cat(x_inner_pad_mask)] = module.x_pad_token + x_list = list(x_embedded.split(x_item_seqlens, dim=0)) + + # Compute rope embeddings for image patches + x_cos, x_sin = module.rope_embedder(torch.cat(x_pos_ids, dim=0)) + x_cos = list(x_cos.split(x_item_seqlens, dim=0)) + x_sin = list(x_sin.split(x_item_seqlens, dim=0)) + + # Pad sequences for batch processing + x_batched = pad_sequence(x_list, batch_first=True, padding_value=0.0) + x_cos_batched = pad_sequence(x_cos, batch_first=True, padding_value=0.0) + x_sin_batched = pad_sequence(x_sin, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + # Run noise refiner blocks + for layer in module.noise_refiner: + x_batched = layer(x_batched, x_attn_mask, x_cos_batched, x_sin_batched, adaln_input) + + # Process caption features through embedder and context refiner + cap_item_seqlens = [len(_) for _ in cap_feats_processed] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_embedded = torch.cat(cap_feats_processed, dim=0) + cap_embedded = module.cap_embedder(cap_embedded) + cap_embedded[torch.cat(cap_inner_pad_mask)] = module.cap_pad_token + cap_list = list(cap_embedded.split(cap_item_seqlens, dim=0)) + + # Compute rope embeddings for caption + cap_cos, cap_sin = module.rope_embedder(torch.cat(cap_pos_ids, dim=0)) + cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0)) + cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0)) + + # Pad sequences for batch processing + cap_batched = pad_sequence(cap_list, batch_first=True, padding_value=0.0) + cap_cos_batched = pad_sequence(cap_cos, batch_first=True, padding_value=0.0) + cap_sin_batched = pad_sequence(cap_sin, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in 
enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + # Run context refiner blocks + for layer in module.context_refiner: + cap_batched = layer(cap_batched, cap_attn_mask, cap_cos_batched, cap_sin_batched) + + # Create unified sequence (image + caption) + unified_list = [] + unified_cos_list = [] + unified_sin_list = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified_list.append(torch.cat([x_batched[i][:x_len], cap_batched[i][:cap_len]])) + unified_cos_list.append(torch.cat([x_cos_batched[i][:x_len], cap_cos_batched[i][:cap_len]])) + unified_sin_list.append(torch.cat([x_sin_batched[i][:x_len], cap_sin_batched[i][:cap_len]])) + + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified_list, batch_first=True, padding_value=0.0) + unified_cos = pad_sequence(unified_cos_list, batch_first=True, padding_value=0.0) + unified_sin = pad_sequence(unified_sin_list, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # ============================================================================ + # EXTRACT MODULATED INPUT (for cache decision) + # ============================================================================ + # Use the first main transformer block's modulation + # The main layers have modulation=True and process the unified sequence + block = module.layers[0] + # Get modulation parameters: scale_msa, gate_msa, scale_mlp, gate_mlp + mod_params = block.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + scale_msa = 1.0 + mod_params[0] + # Extract modulated input: normalized hidden states scaled by modulation + modulated_input = block.attention_norm1(unified) * scale_msa + + # ============================================================================ + # DEFINE TRANSFORMER EXECUTION (Z-Image specific) + # ============================================================================ + def run_transformer_blocks(): + """Execute all Z-Image main transformer blocks.""" + h = unified + for layer in module.layers: + h = layer(h, unified_attn_mask, unified_cos, unified_sin, adaln_input) + return (h,) + + # ============================================================================ + # DEFINE POSTPROCESSING (Z-Image specific) + # ============================================================================ + def postprocess(h): + """Apply Z-Image specific output postprocessing.""" + h = module.all_final_layer[f"{patch_size}-{f_patch_size}"](h, adaln_input) + h = list(h.unbind(dim=0)) + output = module.unpatchify(h, x_size, patch_size, f_patch_size) + return output, {} + + # ============================================================================ + # RETURN CONTEXT + # ============================================================================ + return CacheContext( + modulated_input=modulated_input, + hidden_states=unified, + encoder_hidden_states=None, # Z-Image uses unified sequence, no separate encoder states + temb=adaln_input, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + extra_states={ + "unified_attn_mask": unified_attn_mask, + "unified_cos": unified_cos, + "unified_sin": unified_sin, + "x_size": x_size, + "x_item_seqlens": x_item_seqlens, + "patch_size": patch_size, + "f_patch_size": f_patch_size, + 
}, + ) + + +# Registry for model-specific extractors +# Key: Transformer class name +# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext +# +# Note: Use the transformer class name as specified in pipelines as TeaCache hooks operate +# on the transformer module and multiple pipelines can share the same transformer. +EXTRACTOR_REGISTRY: dict[str, Callable] = { + "QwenImageTransformer2DModel": extract_qwen_context, + "Bagel": extract_bagel_context, + "ZImageTransformer2DModel": extract_zimage_context, + # Future models: + # "FluxTransformer2DModel": extract_flux_context, + # "CogVideoXTransformer3DModel": extract_cogvideox_context, +} + + +def register_extractor(transformer_cls_name: str, extractor_fn: Callable) -> None: + """ + Register a new extractor function for a model type. + + This allows extending TeaCache support to new models without modifying + the core TeaCache code. + + Args: + transformer_cls_name: Transformer model type identifier (class name or type string) + extractor_fn: Function with signature (module, *args, **kwargs) -> CacheContext + + Example: + >>> def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs): + ... # Preprocessing + ... temb = module.time_text_embed(timestep, guidance) + ... # Extract modulated input + ... modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb) + ... # Define execution + ... def run_blocks(): + ... h = hidden_states + ... for block in module.transformer_blocks: + ... h = block(h, temb=temb) + ... return (h,) + ... # Define postprocessing + ... def postprocess(h): + ... return module.proj_out(module.norm_out(h, temb)) + ... # Return context + ... return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess) + >>> register_extractor("FluxTransformer2DModel", extract_flux_context) + """ + EXTRACTOR_REGISTRY[transformer_cls_name] = extractor_fn + + +def get_extractor(transformer_cls_name: str) -> Callable: + """ + Get extractor function for given transformer class. + + This function looks up the extractor based on the exact transformer_cls_name string, + which should match the transformer type in the pipeline (i.e., pipeline.transformer.__class__.__name__). + + Args: + transformer_cls_name: Transformer class name (e.g., "QwenImageTransformer2DModel") + Must exactly match a key in EXTRACTOR_REGISTRY. + + Returns: + Extractor function with signature (module, *args, **kwargs) -> CacheContext + + Raises: + ValueError: If model type not found in registry + + Example: + >>> # Get extractor for QwenImageTransformer2DModel + >>> extractor = get_extractor("QwenImageTransformer2DModel") + >>> ctx = extractor(transformer, hidden_states, encoder_hidden_states, timestep, ...) + """ + # Direct lookup - no substring matching + if transformer_cls_name in EXTRACTOR_REGISTRY: + return EXTRACTOR_REGISTRY[transformer_cls_name] + + # No match found + available_types = list(EXTRACTOR_REGISTRY.keys()) + raise ValueError( + f"Unknown model type: '{transformer_cls_name}'. " + f"Available types: {available_types}\n" + f"To add support for a new model, use register_extractor() or add to EXTRACTOR_REGISTRY." 
+ ) diff --git a/vllm_omni/diffusion/cache/teacache/hook.py b/vllm_omni/diffusion/cache/teacache/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..65f764c43bfc43f87b259fc1da90043de83cd435 --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/hook.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Hook-based TeaCache implementation for vLLM-Omni. + +This module implements a diffusers-style hook system that completely intercepts +the transformer forward pass, eliminating the need for any TeaCache-specific +code in model definitions. Model developers only need to add an extractor function +to support new models. +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig +from vllm_omni.diffusion.cache.teacache.extractors import get_extractor +from vllm_omni.diffusion.cache.teacache.state import TeaCacheState +from vllm_omni.diffusion.distributed.parallel_state import ( + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.hooks import HookRegistry, ModelHook, StateManager + + +class TeaCacheHook(ModelHook): + """ + ModelHook implementing TeaCache for transformer models. + + This hook completely intercepts the transformer's forward pass and implements + adaptive caching based on timestep embedding similarity. It's model-agnostic + and supports multiple model types through extractor functions. + + Key features: + - Zero changes to model code + - CFG-aware with separate states for positive/negative branches + - CFG-parallel compatible: properly detects branch identity across ranks + - Model-specific polynomial rescaling + - Auto-detection of model types + + Attributes: + config: TeaCache configuration with thresholds and callbacks + rescale_func: Polynomial function for rescaling L1 distances + state_manager: Manages TeaCacheState across forward passes + extractor_fn: Model-specific function to extract modulated input + """ + + _HOOK_NAME = "teacache" + + def __init__(self, config: TeaCacheConfig): + """ + Initialize TeaCacheHook. + + Args: + config: TeaCache configuration object. + """ + super().__init__() + self.config = config + self.rescale_func = np.poly1d(config.coefficients) + self.state_manager = StateManager(TeaCacheState) + self.extractor_fn = None + self._forward_cnt = 0 + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + """ + Initialize hook with extractor from config transformer model type. + + Args: + module: The module to initialize the hook for. + + Returns: + The initialized module. + """ + # Get extractor function based on transformer_type from config + # transformer_type is the transformer class name (e.g., "QwenImageTransformer2DModel") + self.extractor_fn = get_extractor(self.config.transformer_type) + + # Set default context + self.state_manager.set_context("teacache") + + return module + + def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: + """ + Generic forward handler that works for ANY model. + + This method is completely model-agnostic. All model-specific logic + is encapsulated in the extractor function that returns a CacheContext. 
+ + The extractor does: + - Model-specific preprocessing + - Extraction of modulated input for cache decision + - Providing transformer execution callable + - Providing postprocessing callable + + This hook does: + - CFG-aware state management + - Cache decision logic (generic) + - Residual caching and reuse + + Args: + module: Transformer module (any architecture) + *args: Positional arguments for model forward + **kwargs: Keyword arguments for model forward + + Returns: + Model output (format depends on model) + """ + # Get model-specific context from extractor + # The extractor encapsulates ALL model-specific logic + ctx = self.extractor_fn(module, *args, **kwargs) + + # ============================================================================ + # GENERIC CACHING LOGIC (works for all models) + # ============================================================================ + # Set context based on CFG branch for separate state tracking + # With CFG-parallel, each rank processes only one branch: + # - cfg_rank 0: positive branch + # - cfg_rank > 0: negative branch + # Without CFG-parallel, branches alternate within a single rank + if getattr(module, "do_true_cfg", False): + cfg_parallel_size = get_classifier_free_guidance_world_size() + if cfg_parallel_size > 1: + cfg_rank = get_classifier_free_guidance_rank() + cache_branch = "negative" if cfg_rank > 0 else "positive" + else: + # No CFG-parallel: use forward counter to alternate branches + cache_branch = "negative" if self._forward_cnt % 2 == 1 else "positive" + else: + cache_branch = "positive" + + context_name = f"teacache_{cache_branch}" + self.state_manager.set_context(context_name) + state = self.state_manager.get_state() + + # Decide whether to compute or cache based on modulated input similarity + should_compute = self._should_compute_full_transformer(state, ctx.modulated_input) + + if not should_compute and state.previous_residual is not None: + # ============================================================================ + # FAST PATH: Reuse cached residuals + # ============================================================================ + ctx.hidden_states = ctx.hidden_states + state.previous_residual + if state.previous_residual_encoder is not None and ctx.encoder_hidden_states is not None: + ctx.encoder_hidden_states = ctx.encoder_hidden_states + state.previous_residual_encoder + output = ctx.hidden_states + else: + # ============================================================================ + # SLOW PATH: Full transformer computation + # ============================================================================ + ori_hidden_states = ctx.hidden_states.clone() + ori_encoder_hidden_states = ( + ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None + ) + + # Run transformer blocks using model-specific callable + outputs = ctx.run_transformer_blocks() + + # Update context with outputs + ctx.hidden_states = outputs[0] + if len(outputs) > 1 and ctx.encoder_hidden_states is not None: + ctx.encoder_hidden_states = outputs[1] + + # Cache residuals for next timestep + state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach() + if ori_encoder_hidden_states is not None: + state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach() + + output = ctx.hidden_states + + # Update state + state.previous_modulated_input = ctx.modulated_input.detach() + state.cnt += 1 + self._forward_cnt += 1 + + # 
============================================================================ + # POSTPROCESSING (model-specific, via callable) + # ============================================================================ + return ctx.postprocess(output) + + def _should_compute_full_transformer(self, state: TeaCacheState, modulated_inp: torch.Tensor) -> bool: + """ + Determine whether to compute full transformer or reuse cached residual. + + This implements the core TeaCache algorithm: + 1. Always compute first timestep + 2. For intermediate steps: + - Compute relative L1 distance between current and previous modulated inputs + - Apply polynomial rescaling with model-specific coefficients + - Accumulate rescaled distances + - Compare to threshold: below = cache, above = compute + + Args: + state: Current TeaCacheState containing counters and cached values + modulated_inp: Modulated input extracted from first transformer block + + Returns: + True to compute full transformer, False to reuse cached residual + """ + # First timestep: always compute + if state.cnt == 0: + state.accumulated_rel_l1_distance = 0.0 + return True + + # Need previous input for comparison + if state.previous_modulated_input is None: + return True + + # Compute relative L1 distance between consecutive modulated inputs + rel_distance = ( + ( + (modulated_inp - state.previous_modulated_input).abs().mean() + / (state.previous_modulated_input.abs().mean() + 1e-8) + ) + .cpu() + .item() + ) + + # Apply model-specific polynomial rescaling + rescaled_distance = float(self.rescale_func(rel_distance)) + state.accumulated_rel_l1_distance += abs(rescaled_distance) + + # Decision: below threshold = cache, above = compute + if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: + return False # Use cache + else: + state.accumulated_rel_l1_distance = 0.0 # Reset accumulator + return True # Compute + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + """ + Reset all cached states for a new inference run. + + Args: + module: The module to reset state for. + + Returns: + The module with reset state. + """ + self.state_manager.reset() + self._forward_cnt = 0 + return module + + +def apply_teacache_hook(module: torch.nn.Module, config: TeaCacheConfig) -> None: + """ + Apply TeaCache optimization to a transformer module. + + This function registers a TeaCacheHook that completely intercepts the + module's forward pass, implementing adaptive caching without any changes + to the model code. + + Args: + module: Transformer model to optimize (e.g., QwenImageTransformer2DModel) + config: TeaCacheConfig specifying caching parameters + + Example: + >>> config = TeaCacheConfig( + ... rel_l1_thresh=0.2, + ... transformer_type="QwenImageTransformer2DModel" + ... ) + >>> apply_teacache_hook(transformer, config) + >>> # Transformer bound to the pipeline now uses TeaCache automatically, + ... # no code changes needed! + """ + registry = HookRegistry.get_or_create(module) + hook = TeaCacheHook(config) + registry.register_hook(TeaCacheHook._HOOK_NAME, hook) diff --git a/vllm_omni/diffusion/cache/teacache/state.py b/vllm_omni/diffusion/cache/teacache/state.py new file mode 100644 index 0000000000000000000000000000000000000000..a6429e54019539fc80ca58e3238b06ca0c2bc03e --- /dev/null +++ b/vllm_omni/diffusion/cache/teacache/state.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +TeaCache state management. 
+ +This module manages the state for TeaCache hooks across diffusion timesteps. +""" + +import torch + + +class TeaCacheState: + """ + State management for TeaCache hook. + + Tracks caching state across diffusion timesteps, managing counters, + accumulated distances, and cached residuals for the TeaCache algorithm. + """ + + def __init__(self): + """Initialize empty TeaCache state.""" + # Timestep tracking + self.cnt = 0 + + # Caching state + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input: torch.Tensor | None = None + self.previous_residual: torch.Tensor | None = None + self.previous_residual_encoder: torch.Tensor | None = None + + def reset(self) -> None: + """Reset all state variables for a new inference run.""" + self.cnt = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + self.previous_residual_encoder = None diff --git a/vllm_omni/diffusion/compile.py b/vllm_omni/diffusion/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..66a4675f3cea3e7de620d93560a4bf6109e45926 --- /dev/null +++ b/vllm_omni/diffusion/compile.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch.nn as nn +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def regionally_compile(model: nn.Module, *compile_args: Any, **compile_kwargs: Any) -> nn.Module: + """ + Apply regional compilation to a PyTorch model. + + Args: + model: The PyTorch model instance to compile + *compile_args: Positional arguments forwarded to torch.compile + **compile_kwargs: Keyword arguments forwarded to torch.compile + + Returns: + The same model instance (modified in-place) + """ + # Get the list of repeated blocks from the model + repeated_blocks = getattr(model, "_repeated_blocks", None) + + if not repeated_blocks: + logger.warning("Regional compilation skipped because the model does not define `_repeated_blocks`.") + return model + + # Check if we have modules with the specified class names + has_compiled_region = False + for submod in model.modules(): + if submod.__class__.__name__ in repeated_blocks: + # Compile this submodule + submod.compile(*compile_args, **compile_kwargs) + has_compiled_region = True + + if not has_compiled_region: + logger.warning(f"Regional compilation skipped because {repeated_blocks} classes are not found in the model.") + + return model diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py new file mode 100644 index 0000000000000000000000000000000000000000..f884fb6f177c5b227682da117ae01330d0fe8b6f --- /dev/null +++ b/vllm_omni/diffusion/data.py @@ -0,0 +1,519 @@ +# adapted from sglang and fastvideo +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum +import os +import random +from collections.abc import Callable +from dataclasses import dataclass, field, fields +from typing import Any + +import torch +from pydantic import model_validator +from typing_extensions import Self +from vllm.config.utils import config +from vllm.logger import init_logger + +from vllm_omni.diffusion.utils.network_utils import is_port_available + +logger = init_logger(__name__) + + +@config +@dataclass +class DiffusionParallelConfig: + """Configuration for diffusion model distributed execution.""" + + pipeline_parallel_size: int = 1 + """Number of pipeline parallel stages.""" + + data_parallel_size: 
int = 1 + """Number of data parallel groups.""" + + tensor_parallel_size: int = 1 + """Number of tensor parallel groups.""" + + sequence_parallel_size: int | None = None + """Number of sequence parallel groups. sequence_parallel_size = ring_degree * ulysses_degree""" + + ulysses_degree: int = 1 + """Number of GPUs used for ulysses sequence parallelism.""" + + ring_degree: int = 1 + """Number of GPUs used for ring sequence parallelism.""" + + cfg_parallel_size: int = 1 + """Number of Classifier Free Guidance (CFG) parallel groups.""" + + @model_validator(mode="after") + def _validate_parallel_config(self) -> Self: + """Validates the config relationships among the parallel strategies.""" + assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0" + assert self.data_parallel_size > 0, "Data parallel size must be > 0" + assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0" + assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0" + assert self.ulysses_degree > 0, "Ulysses degree must be > 0" + assert self.ring_degree > 0, "Ring degree must be > 0" + assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" + assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}" + assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( + "Sequence parallel size must be equal to the product of ulysses degree and ring degree," + f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" + ) + return self + + def __post_init__(self) -> None: + if self.sequence_parallel_size is None: + self.sequence_parallel_size = self.ulysses_degree * self.ring_degree + self.world_size = ( + self.pipeline_parallel_size + * self.data_parallel_size + * self.tensor_parallel_size + * self.ulysses_degree + * self.ring_degree + * self.cfg_parallel_size + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DiffusionParallelConfig": + """ + Create DiffusionParallelConfig from a dictionary. + + Args: + data: Dictionary containing parallel configuration parameters + + Returns: + DiffusionParallelConfig instance with parameters set from dict + """ + if not isinstance(data, dict): + raise TypeError(f"Expected parallel config dict, got {type(data)!r}") + return cls(**data) + + +@dataclass +class TransformerConfig: + """Container for raw transformer configuration dictionaries.""" + + params: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig": + if not isinstance(data, dict): + raise TypeError(f"Expected transformer config dict, got {type(data)!r}") + return cls(params=dict(data)) + + def to_dict(self) -> dict[str, Any]: + return dict(self.params) + + def get(self, key: str, default: Any | None = None) -> Any: + return self.params.get(key, default) + + def __getattr__(self, item: str) -> Any: + params = object.__getattribute__(self, "params") + try: + return params[item] + except KeyError as exc: + raise AttributeError(item) from exc + + +@dataclass +class DiffusionCacheConfig: + """ + Configuration for cache adapters (TeaCache, cache-dit, etc.). + + This dataclass provides a unified interface for cache configuration parameters. + It can be initialized from a dictionary and accessed via attributes. 
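+    Keys that do not match a declared field are preserved in _extra_params and stay
+    reachable through attribute access (see __getattr__ below).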
+ + Common parameters: + - TeaCache: rel_l1_thresh, coefficients (optional) + - cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps, + residual_diff_threshold, enable_taylorseer, taylorseer_order, + scm_steps_mask_policy, scm_steps_policy + + Example: + >>> # From dict (user-facing API) - partial config uses defaults for missing keys + >>> config = DiffusionCacheConfig.from_dict({"rel_l1_thresh": 0.3}) + >>> # Access via attribute + >>> print(config.rel_l1_thresh) # 0.3 (from dict) + >>> print(config.Fn_compute_blocks) # 8 (default) + >>> # Empty dict uses all defaults + >>> default_config = DiffusionCacheConfig.from_dict({}) + >>> print(default_config.rel_l1_thresh) # 0.2 (default) + """ + + # TeaCache parameters [tea_cache only] + # Default: 0.2 provides ~1.5x speedup with minimal quality loss (optimal balance) + rel_l1_thresh: float = 0.2 + coefficients: list[float] | None = None # Uses model-specific defaults if None + + # cache-dit parameters [cache-dit only] + # Default: 1 forward compute block (optimized for single-transformer models) + # Use 1 as default instead of cache-dit's 8, optimized for single-transformer models + # This provides better performance while maintaining quality for most use cases + Fn_compute_blocks: int = 1 + # Default: 0 backward compute blocks (no fusion by default) + Bn_compute_blocks: int = 0 + # Default: 4 warmup steps (optimized for few-step distilled models like Z-Image with 8 steps) + # Use 4 as default warmup steps instead of 8 in cache-dit, making DBCache work + # for few-step distilled models (e.g., Z-Image with 8 steps) + max_warmup_steps: int = 4 + # Default: -1 (unlimited cached steps) - DBCache disables caching when previous cached steps exceed this value + # to prevent precision degradation. Set to -1 for unlimited caching (cache-dit default). + max_cached_steps: int = -1 + # Default: 0.24 residual difference threshold (higher for more aggressive caching) + # Use a relatively higher residual diff threshold (0.24) as default to allow more + # aggressive caching. This is safe because we have max_continuous_cached_steps limit. + # Without this limit, a lower threshold like 0.12 would be needed. + residual_diff_threshold: float = 0.24 + # Default: Limit consecutive cached steps to 3 to prevent precision degradation + # This allows us to use a higher residual_diff_threshold for more aggressive caching + max_continuous_cached_steps: int = 3 + # Default: Disable TaylorSeer (not suitable for few-step distilled models) + # TaylorSeer is not suitable for few-step distilled models, so we disable it by default. + # References: + # - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers + # - Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers + enable_taylorseer: bool = False + # Default: 1st order TaylorSeer polynomial + taylorseer_order: int = 1 + # Default: None SCM mask policy (disabled by default) + scm_steps_mask_policy: str | None = None + # Default: "dynamic" steps policy for adaptive caching + scm_steps_policy: str = "dynamic" + # Used by cache-dit for scm mask generation. If this value changes during inference, + # we will re-generate the scm mask and refresh the cache context. 
+ num_inference_steps: int | None = None + + # Additional parameters that may be passed but not explicitly defined + _extra_params: dict[str, Any] = field(default_factory=dict, repr=False) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DiffusionCacheConfig": + """ + Create DiffusionCacheConfig from a dictionary. + + Args: + data: Dictionary containing cache configuration parameters + + Returns: + DiffusionCacheConfig instance with parameters set from dict + """ + if not isinstance(data, dict): + raise TypeError(f"Expected cache config dict, got {type(data)!r}") + + # Get all dataclass field names automatically + field_names = {f.name for f in fields(cls)} + + # Extract parameters that match dataclass fields (excluding private fields) + known_params = {k: v for k, v in data.items() if k in field_names and not k.startswith("_")} + + # Store extra parameters + extra_params = {k: v for k, v in data.items() if k not in field_names} + + # Create instance with known params (missing ones will use defaults) + # Then update _extra_params after creation since it's a private field + instance = cls(**known_params, _extra_params=extra_params) + return instance + + def __getattr__(self, item: str) -> Any: + """ + Allow access to extra parameters via attribute access. + + This enables accessing parameters that weren't explicitly defined + in the dataclass fields but were passed in the dict. + """ + if item == "_extra_params" or item.startswith("_"): + return object.__getattribute__(self, item) + + extra = object.__getattribute__(self, "_extra_params") + if item in extra: + return extra[item] + + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + + +@dataclass +class OmniDiffusionConfig: + # Model and path configuration (for convenience) + model: str | None = None + + model_class_name: str | None = None + + dtype: torch.dtype = torch.bfloat16 + + tf_model_config: TransformerConfig = field(default_factory=TransformerConfig) + + # Attention + attention_backend: str | None = None + + # Running mode + # mode: ExecutionMode = ExecutionMode.INFERENCE + + # Workload type + # workload_type: WorkloadType = WorkloadType.T2V + + # Cache strategy (legacy) + cache_strategy: str = "none" + parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig) + + # Cache backend configuration (NEW) + cache_backend: str = "none" # "tea_cache", "deep_cache", etc. 
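+    # Illustrative usage (values are examples only):
+    #   OmniDiffusionConfig(cache_backend="tea_cache", cache_config={"rel_l1_thresh": 0.2})
+    # A plain dict is accepted here; __post_init__ converts it to DiffusionCacheConfig.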
+ cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict) + enable_cache_dit_summary: bool = False + + # Distributed executor backend + distributed_executor_backend: str = "mp" + nccl_port: int | None = None + + # HuggingFace specific parameters + trust_remote_code: bool = False + revision: str | None = None + + num_gpus: int | None = None + + hsdp_replicate_dim: int = 1 + hsdp_shard_dim: int = -1 + dist_timeout: int | None = None # timeout for torch.distributed + + # pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False) + + # LoRA parameters + lora_path: str | None = None + lora_scale: float = 1.0 + max_cpu_loras: int | None = None + + output_type: str = "pil" + + # CPU offload parameters + # When enabled, DiT and encoders swap GPU access (mutual exclusion): + # - Text encoders run on GPU while DiT is on CPU + # - DiT runs on GPU while encoders are on CPU + enable_cpu_offload: bool = False + + # Layer-wise offloading (block-level offloading) parameters + enable_layerwise_offload: bool = False + # Number of transformer blocks ready for computation to keep on GPU + layerwise_num_gpu_layers: int = 1 + + use_fsdp_inference: bool = False + pin_cpu_memory: bool = True # Use pinned memory for faster transfers when offloading + + # VAE memory optimization parameters + vae_use_slicing: bool = False + vae_use_tiling: bool = False + + # STA (Sliding Tile Attention) parameters + mask_strategy_file_path: str | None = None + # STA_mode: STA_Mode = STA_Mode.STA_INFERENCE + skip_time_steps: int = 15 + + # Compilation + enforce_eager: bool = False + + # Enable sleep mode + enable_sleep_mode: bool = False + + disable_autocast: bool = False + + # VSA parameters + VSA_sparsity: float = 0.0 # inference/validation sparsity + + # V-MoBA parameters + moba_config_path: str | None = None + # moba_config: dict[str, Any] = field(default_factory=dict) + + # Master port for distributed inference + # TODO: do not hard code + master_port: int | None = None + + # http server endpoint config, would be ignored in local mode + host: str | None = None + port: int | None = None + + scheduler_port: int = 5555 + + # Stage verification + enable_stage_verification: bool = True + + # Prompt text file for batch processing + prompt_file_path: str | None = None + + # model paths for correct deallocation + model_paths: dict[str, str] = field(default_factory=dict) + model_loaded: dict[str, bool] = field( + default_factory=lambda: { + "transformer": True, + "vae": True, + } + ) + override_transformer_cls_name: str | None = None + + # # DMD parameters + # dmd_denoising_steps: List[int] | None = field(default=None) + + # MoE parameters used by Wan2.2 + boundary_ratio: float | None = None + # Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p) + flow_shift: float | None = None + + # support multi images input + supports_multimodal_inputs: bool = False + + # Logging + log_level: str = "info" + + # Omni configuration (injected from stage config) + omni_kv_config: dict[str, Any] = field(default_factory=dict) + + def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int: + """ + Find an available port with retry logic. 
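+        Candidates start at the given port and are bumped by port_inc after each
+        failed check; once a candidate exceeds 60000, the search wraps back to a
+        randomized port between 5000 and 6000.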
+ + Args: + port: Initial port to check + port_inc: Port increment for each attempt + max_attempts: Maximum number of attempts to find an available port + + Returns: + An available port number + + Raises: + RuntimeError: If no available port is found after max_attempts + """ + attempts = 0 + original_port = port + + while attempts < max_attempts: + if is_port_available(port): + if attempts > 0: + logger.info(f"Port {original_port} was unavailable, using port {port} instead") + return port + + attempts += 1 + if port < 60000: + port += port_inc + else: + # Wrap around with randomization to avoid collision + port = 5000 + random.randint(0, 1000) + + raise RuntimeError( + f"Failed to find available port after {max_attempts} attempts (started from port {original_port})" + ) + + def __post_init__(self): + # TODO: remove hard code + initial_master_port = (self.master_port or 30005) + random.randint(0, 100) + self.master_port = self.settle_port(initial_master_port, 37) + + # Convert parallel_config dict to DiffusionParallelConfig if needed + # This must be done before accessing parallel_config.world_size + if isinstance(self.parallel_config, dict): + self.parallel_config = DiffusionParallelConfig.from_dict(self.parallel_config) + elif not isinstance(self.parallel_config, DiffusionParallelConfig): + # If it's neither dict nor DiffusionParallelConfig, use default config + self.parallel_config = DiffusionParallelConfig() + + if self.num_gpus is None: + if self.parallel_config is not None: + self.num_gpus = self.parallel_config.world_size + else: + self.num_gpus = 1 + + if self.num_gpus < self.parallel_config.world_size: + raise ValueError( + f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})" + ) + + # Convert string dtype to torch.dtype if needed + if isinstance(self.dtype, str): + dtype_map = { + "auto": torch.bfloat16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float16": torch.float16, + "fp16": torch.float16, + "half": torch.float16, + "float32": torch.float32, + "fp32": torch.float32, + "float": torch.float32, + } + dtype_lower = self.dtype.lower() + if dtype_lower in dtype_map: + self.dtype = dtype_map[dtype_lower] + else: + logger.warning(f"Unknown dtype string '{self.dtype}', defaulting to bfloat16") + self.dtype = torch.bfloat16 + + # Convert cache_config dict to DiffusionCacheConfig if needed + if isinstance(self.cache_config, dict): + self.cache_config = DiffusionCacheConfig.from_dict(self.cache_config) + elif not isinstance(self.cache_config, DiffusionCacheConfig): + # If it's neither dict nor DiffusionCacheConfig, convert to empty config + self.cache_config = DiffusionCacheConfig() + + if self.max_cpu_loras is None: + self.max_cpu_loras = 1 + elif self.max_cpu_loras < 1: + raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + + def update_multimodal_support(self) -> None: + self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} + + @classmethod + def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": + # Backwards-compatibility: older callers may use a diffusion-specific + # "static_lora_scale" kwarg. Normalize it to the canonical "lora_scale" + # before constructing the dataclass to avoid TypeError on unknown fields. 
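+        # For example (illustrative values), from_kwargs(model=..., static_lora_scale=0.5)
+        # is handled the same as from_kwargs(model=..., lora_scale=0.5).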
+ if "static_lora_scale" in kwargs: + if "lora_scale" not in kwargs: + kwargs["lora_scale"] = kwargs["static_lora_scale"] + kwargs.pop("static_lora_scale", None) + + # Check environment variable as fallback for cache_backend + # Support both old DIFFUSION_CACHE_ADAPTER and new DIFFUSION_CACHE_BACKEND for backwards compatibility + if "cache_backend" not in kwargs: + cache_backend = os.environ.get("DIFFUSION_CACHE_BACKEND") or os.environ.get("DIFFUSION_CACHE_ADAPTER") + kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none" + + # Filter kwargs to only include valid fields + valid_fields = {f.name for f in fields(cls)} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields} + + return cls(**filtered_kwargs) + + +@dataclass +class DiffusionOutput: + """ + Final output (after pipeline completion) + """ + + output: torch.Tensor | None = None + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + trajectory_decoded: list[torch.Tensor] | None = None + error: str | None = None + + post_process_func: Callable[..., Any] | None = None + + # logged timings info, directly from Req.timings + # timings: Optional["RequestTimings"] = None + + +class AttentionBackendEnum(enum.Enum): + FA = enum.auto() + SLIDING_TILE_ATTN = enum.auto() + TORCH_SDPA = enum.auto() + SAGE_ATTN = enum.auto() + SAGE_ATTN_THREE = enum.auto() + VIDEO_SPARSE_ATTN = enum.auto() + VMOBA_ATTN = enum.auto() + AITER = enum.auto() + NO_ATTENTION = enum.auto() + + def __str__(self): + return self.name.lower() + + +# Special message broadcast via scheduler queues to signal worker shutdown. +SHUTDOWN_MESSAGE = {"type": "shutdown"} diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..8e19f124266d1988436bdba7fe86b29aaf005f8e --- /dev/null +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time +from collections.abc import Iterable +from typing import Any + +import PIL.Image +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.executor.abstract import DiffusionExecutor +from vllm_omni.diffusion.registry import ( + DiffusionModelRegistry, + get_diffusion_post_process_func, + get_diffusion_pre_process_func, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + + +def supports_image_input(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_image_input", False)) + + +def image_color_format(model_class_name: str) -> str: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + return getattr(model_cls, "color_format", "RGB") + + +def supports_audio_output(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_audio_output", False)) + + +class DiffusionEngine: + """The diffusion engine for vLLM-Omni diffusion models.""" + + def __init__(self, od_config: OmniDiffusionConfig): + 
"""Initialize the diffusion engine. + + Args: + config: The configuration for the diffusion engine. + """ + self.od_config = od_config + + self.post_process_func = get_diffusion_post_process_func(od_config) + self.pre_process_func = get_diffusion_pre_process_func(od_config) + + executor_class = DiffusionExecutor.get_class(od_config) + self.executor = executor_class(od_config) + + try: + self._dummy_run() + except Exception as e: + logger.error(f"Dummy run failed: {e}") + self.close() + raise e + + def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: + # Apply pre-processing if available + if self.pre_process_func is not None: + preprocess_start_time = time.time() + request = self.pre_process_func(request) + preprocess_time = time.time() - preprocess_start_time + logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds") + + output = self.add_req_and_wait_for_response(request) + if output.error: + raise Exception(f"{output.error}") + logger.info("Generation completed successfully.") + + if output.output is None: + logger.warning("Output is None, returning empty OmniRequestOutput") + return [ + OmniRequestOutput.from_diffusion( + request_id=request.request_ids[i] if i < len(request.request_ids) else "", + images=[], + prompt=prompt, + metrics={}, + latents=None, + ) + for i, prompt in enumerate(request.prompts) + ] + + postprocess_start_time = time.time() + outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output + postprocess_time = time.time() - postprocess_start_time + logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds") + + # Convert to OmniRequestOutput format + # Ensure outputs is a list + if not isinstance(outputs, list): + outputs = [outputs] if outputs is not None else [] + + # Handle single request or multiple requests + if len(request.prompts) == 1: + # Single request: return single OmniRequestOutput + prompt = request.prompts[0] + request_id = request.request_ids[0] if request.request_ids else "" + + metrics = {} + if output.trajectory_timesteps is not None: + metrics["trajectory_timesteps"] = output.trajectory_timesteps + + if supports_audio_output(self.od_config.model_class_name): + audio_payload = outputs[0] if len(outputs) == 1 else outputs + return [ + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + multimodal_output={"audio": audio_payload}, + final_output_type="audio", + ), + ] + else: + return [ + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=outputs, + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + ), + ] + else: + # Multiple requests: return list of OmniRequestOutput + # Split images based on num_outputs_per_prompt for each request + results = [] + output_idx = 0 + + for i, prompt in enumerate(request.prompts): + request_id = request.request_ids[i] if i < len(request.request_ids) else "" + + # Get images for this request + num_outputs = request.sampling_params.num_outputs_per_prompt + request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else [] + output_idx += num_outputs + + metrics = {} + if output.trajectory_timesteps is not None: + metrics["trajectory_timesteps"] = output.trajectory_timesteps + + if supports_audio_output(self.od_config.model_class_name): + audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs + results.append( + 
OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + multimodal_output={"audio": audio_payload}, + final_output_type="audio", + ) + ) + else: + results.append( + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=request_outputs, + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + ) + ) + + return results + + @staticmethod + def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": + """Factory method to create a DiffusionEngine instance. + + Args: + config: The configuration for the diffusion engine. + + Returns: + An instance of DiffusionEngine. + """ + return DiffusionEngine(config) + + def add_req_and_wait_for_response(self, request: OmniDiffusionRequest): + return self.executor.add_req(request) + + def start_profile(self, trace_filename: str | None = None) -> None: + """ + Start torch profiling on all diffusion workers. + + Creates a directory (if needed) and sets up a base filename template + for per-rank profiler traces (typically saved as <template>_rank<N>.json). + + Args: + trace_filename: Optional base filename (without extension or rank suffix). + If None, generates one using current timestamp. + """ + if trace_filename is None: + trace_filename = f"stage_0_diffusion_{int(time.time())}_rank" + + trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") + + # Expand ~ and ~user, then make absolute (robust against cwd changes) + trace_dir = os.path.expanduser(trace_dir) + trace_dir = os.path.abspath(trace_dir) + + try: + os.makedirs(trace_dir, exist_ok=True) + except OSError as exc: + logger.error(f"Failed to create profiler directory {trace_dir}: {exc}") + raise + + # Build final template path (without rank or extension — torch.profiler appends those) + full_template = os.path.join(trace_dir, trace_filename) + + expected_pattern = f"{full_template}*.json" + logger.info(f"Starting diffusion profiling → {expected_pattern}") + + # Also log the absolute directory once (useful in multi-node or containers) + logger.debug(f"Profiler output directory: {trace_dir}") + + # Propagate to all workers + try: + self.collective_rpc(method="start_profile", args=(full_template,)) + except Exception as e: + logger.error("Failed to start profiling on workers", exc_info=True) + raise RuntimeError(f"Could not start profiler: {e}") from e + + def stop_profile(self) -> dict: + """ + Stop profiling on all workers and collect the final trace/table paths. + + The worker (torch_profiler.py) now handles trace export, compression to .gz, + and deletion of the original .json file. This method only collects and + reports the paths returned by the workers. 
+ + Returns: + dict with keys: + - "traces": list of final trace file paths (usually .json.gz) + - "tables": list of table strings (one per rank) + """ + logger.info("Stopping diffusion profiling and collecting results...") + + try: + # Give worker enough time — export + compression + table can be slow + results = self.collective_rpc(method="stop_profile", timeout=600) + except Exception: + logger.error("Failed to stop profiling on workers", exc_info=True) + return {"traces": [], "tables": []} + + output_files = {"traces": [], "tables": []} + successful_traces = 0 + + if not results: + logger.warning("No profiling results returned from any rank") + return output_files + + for rank, res in enumerate(results): + if not isinstance(res, dict): + logger.warning(f"Rank {rank}: invalid result format (got {type(res)})") + continue + + # 1. Trace file — should be .json.gz if compression succeeded + trace_path = res.get("trace") + if trace_path: + # We trust the worker — it created/compressed the file + logger.info(f"[Rank {rank}] Final trace: {trace_path}") + output_files["traces"].append(trace_path) + successful_traces += 1 + + # Optional: warn if path looks suspicious (e.g. still .json) + if not trace_path.endswith((".json.gz", ".json")): + logger.warning(f"Rank {rank}: unusual trace path extension: {trace_path}") + + # 2. Table file — plain text + table = res.get("table") + if table: + output_files["tables"].append(table) + + # Final summary logging + num_ranks = len(results) + if successful_traces > 0: + final_paths_str = ", ".join(output_files["traces"][:3]) + if len(output_files["traces"]) > 3: + final_paths_str += f" ... (+{len(output_files['traces']) - 3} more)" + + logger.info( + f"Profiling stopped. Collected {successful_traces} trace file(s) " + f"from {num_ranks} rank(s). " + f"Final trace paths: {final_paths_str}" + ) + elif output_files["traces"]: + logger.info( + f"Profiling stopped but no traces were successfully collected. " + f"Reported paths: {', '.join(output_files['traces'][:3])}" + f"{' ...' if len(output_files['traces']) > 3 else ''}" + ) + else: + logger.info("Profiling stopped — no trace files were collected from any rank.") + + if output_files["tables"]: + logger.debug(f"Collected {len(output_files['tables'])} profiling table(s)") + + return output_files + + def _dummy_run(self): + """A dummy run to warm up the model.""" + num_inference_steps = 1 + height = 1024 + width = 1024 + if supports_image_input(self.od_config.model_class_name): + # Provide a dummy image input if the model supports it + color_format = image_color_format(self.od_config.model_class_name) + dummy_image = PIL.Image.new(color_format, (width, height)) + else: + dummy_image = None + prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}} + req = OmniDiffusionRequest( + prompts=[prompt], + sampling_params=OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + num_outputs_per_prompt=1, + ), + ) + logger.info("dummy run to warm up the model") + request = self.pre_process_func(req) if self.pre_process_func is not None else req + self.add_req_and_wait_for_response(request) + + def collective_rpc( + self, + method: str, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + unique_reply_rank: int | None = None, + ) -> Any: + """Call a method on worker processes and get results immediately. 
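+        For example, collective_rpc(method="start_profile", args=(path,)) runs
+        start_profile(path) on every worker and returns one result per rank unless
+        unique_reply_rank is set.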
+ + Args: + method: The method name (str) to execute on workers + timeout: Optional timeout in seconds + args: Positional arguments for the method + kwargs: Keyword arguments for the method + unique_reply_rank: If set, only get reply from this rank + + Returns: + Single result if unique_reply_rank is provided, otherwise list of results + """ + assert isinstance(method, str), "Only string method names are supported for now" + return self.executor.collective_rpc( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + unique_reply_rank=unique_reply_rank, + ) + + def close(self) -> None: + if hasattr(self, "executor"): + self.executor.shutdown() + + def abort(self, request_id: str | Iterable[str]) -> None: + # TODO implement it + logger.warning("DiffusionEngine abort is not implemented yet") + pass diff --git a/vllm_omni/diffusion/distributed/__init__.py b/vllm_omni/diffusion/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c791ae1be64795e9fdc2f8a3b21817be702416ec --- /dev/null +++ b/vllm_omni/diffusion/distributed/__init__.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Distributed utilities for vLLM-Omni diffusion models.""" + +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelConfig, + SequenceParallelInput, + SequenceParallelModelPlan, + SequenceParallelOutput, + SequenceParallelPartialInput, + get_sp_plan_from_model, + validate_sp_plan, +) +from vllm_omni.diffusion.distributed.sp_sharding import ( + ShardingValidator, + get_sharding_validator, + sp_gather, + sp_shard, + sp_shard_with_padding, +) + +__all__ = [ + # Config + "SequenceParallelConfig", + # Plan types + "SequenceParallelInput", + "SequenceParallelOutput", + "SequenceParallelPartialInput", + "SequenceParallelModelPlan", + "validate_sp_plan", + "get_sp_plan_from_model", + # Sharding utilities + "sp_shard", + "sp_gather", + "sp_shard_with_padding", + "ShardingValidator", + "get_sharding_validator", +] diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9f86bce228b00b6477bfbcba5a76a3e7da6ff587 --- /dev/null +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Base pipeline class for Diffusion models with shared CFG functionality. +""" + +from abc import ABCMeta +from typing import Any + +import torch + +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) + + +class CFGParallelMixin(metaclass=ABCMeta): + """ + Base Mixin class for Diffusion pipelines providing shared CFG methods. + + All pipelines should inherit from this class to reuse + classifier-free guidance logic. + """ + + def predict_noise_maybe_with_cfg( + self, + do_true_cfg: bool, + true_cfg_scale: float, + positive_kwargs: dict[str, Any], + negative_kwargs: dict[str, Any] | None, + cfg_normalize: bool = True, + output_slice: int | None = None, + ) -> torch.Tensor | None: + """ + Predict noise with optional classifier-free guidance. 
+ + Args: + do_true_cfg: Whether to apply CFG + true_cfg_scale: CFG scale factor + positive_kwargs: Kwargs for positive/conditional prediction + negative_kwargs: Kwargs for negative/unconditional prediction + cfg_normalize: Whether to normalize CFG output (default: True) + output_slice: If set, slice output to [:, :output_slice] for image editing + + Returns: + Predicted noise tensor (only valid on rank 0 in CFG parallel mode) + """ + if do_true_cfg: + # Automatically detect CFG parallel configuration + cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_rank == 0: + local_pred = self.predict_noise(**positive_kwargs) + else: + local_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines (remove condition latents) + if output_slice is not None: + local_pred = local_pred[:, :output_slice] + + gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + + if cfg_rank == 0: + noise_pred = gathered[0] + neg_noise_pred = gathered[1] + noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) + return noise_pred + else: + return None + else: + # Sequential CFG: compute both positive and negative + positive_noise_pred = self.predict_noise(**positive_kwargs) + negative_noise_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines + if output_slice is not None: + positive_noise_pred = positive_noise_pred[:, :output_slice] + negative_noise_pred = negative_noise_pred[:, :output_slice] + + noise_pred = self.combine_cfg_noise( + positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize + ) + return noise_pred + else: + # No CFG: only compute positive/conditional prediction + pred = self.predict_noise(**positive_kwargs) + if output_slice is not None: + pred = pred[:, :output_slice] + return pred + + def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tensor) -> torch.Tensor: + """ + Normalize the combined noise prediction. + + Args: + noise_pred: positive noise prediction + comb_pred: combined noise prediction after CFG + + Returns: + Normalized noise prediction tensor + """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + return noise_pred + + def combine_cfg_noise( + self, noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float, cfg_normalize: bool = False + ) -> torch.Tensor: + """ + Combine conditional and unconditional noise predictions with CFG. + + Args: + noise_pred: Conditional noise prediction + neg_noise_pred: Unconditional noise prediction + true_cfg_scale: CFG scale factor + cfg_normalize: Whether to normalize the combined prediction (default: False) + + Returns: + Combined noise prediction tensor + """ + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if cfg_normalize: + noise_pred = self.cfg_normalize_function(noise_pred, comb_pred) + else: + noise_pred = comb_pred + + return noise_pred + + def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. 
+ """ + return self.transformer(*args, **kwargs)[0] + + def diffuse( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Diffusion loop with optional classifier-free guidance. + + Subclasses MUST implement this method to define the complete + diffusion/denoising loop for their specific model. + + Typical implementation pattern: + ```python + def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): + for t in timesteps: + # Prepare kwargs for positive and negative predictions + positive_kwargs = {...} + negative_kwargs = {...} + + # Predict noise with automatic CFG handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=self.guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + + # Step scheduler with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg=True + ) + + return latents + ``` + """ + raise NotImplementedError("Subclasses must implement diffuse") + + def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Step the scheduler. + + Args: + noise_pred: Predicted noise + t: Current timestep + latents: Current latents + + Returns: + Updated latents after scheduler step + """ + return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + def scheduler_step_maybe_with_cfg( + self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, do_true_cfg: bool + ) -> torch.Tensor: + """ + Step the scheduler with (maybe) automatic CFG parallel synchronization. + + In CFG parallel mode, only rank 0 computes the scheduler step, + then broadcasts the result to other ranks. + + Args: + noise_pred: Predicted noise (only valid on rank 0 in CFG parallel) + t: Current timestep + latents: Current latents + do_true_cfg: Whether CFG is enabled + + Returns: + Updated latents (synchronized across all CFG ranks) + """ + # Automatically detect CFG parallel configuration + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + # Only rank 0 computes the scheduler step + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + + # Broadcast the updated latents to all ranks + latents = latents.contiguous() + cfg_group.broadcast(latents, src=0) + else: + # No CFG parallel: directly compute scheduler step + latents = self.scheduler_step(noise_pred, t, latents) + + return latents diff --git a/vllm_omni/diffusion/distributed/comm.py b/vllm_omni/diffusion/distributed/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..f73b8442f36f148ea15bdeca25bdfa271c4ec97c --- /dev/null +++ b/vllm_omni/diffusion/distributed/comm.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team & Jiarui Fang +# from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/all_to_all.py +from typing import Any + +import torch +import torch.distributed as dist +from torch import Tensor + +from vllm_omni.platforms import current_omni_platform + +__all__ = ["all_to_all_4D", "all_to_all_5D", "SeqAllToAll4D", "SeqAllToAll5D", "RingComm"] + + +def all_to_all_4D( + input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): 
a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group (torch.distributed.ProcessGroup): torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert input.dim() == 4, f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous() + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + current_omni_platform.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
+ # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> + # (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + current_omni_platform.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + use_sync: bool = False, + ) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + +def all_to_all_5D( + input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + forward (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs) + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group (torch.distributed.ProcessGroup): torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs) + """ + assert input.dim() == 5, f"input must be 5D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 3 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, 3, hc, hs) output: (bs, seqlen, 3, hc/P, hs) + bs, shard_seqlen, t_cnt, hc, hs = input.shape + + assert t_cnt == 3 + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
+ # (bs, seqlen/P, 3, hc, hs) -reshape-> (bs, seq_len/P, 3, P, hc/P, hs) -transpose(0,3)-> + # (P, seq_len/P, 3, bs, hc/P, hs) + input_t = input.reshape(bs, shard_seqlen, 3, seq_world_size, shard_hc, hs).transpose(0, 3).contiguous() + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + current_omni_platform.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, 3, bs, shard_hc, hs) + + # (seq_len, 3, bs, hc/P, hs) -trans-> (bs, seq_len, 3, hc/P, hs) + output = output.transpose(0, 2).transpose(1, 2).contiguous() + + return output.reshape(bs, seqlen, 3, shard_hc, hs).contiguous() + elif scatter_idx == 1 and gather_idx == 3: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, _, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, 3, hc/P, hs) -reshape-> (bs, P, seq_len/P, 3, hc/P, hs) -transpose(0, 4)-> + # (hc/P, P, seqlen/P, 3, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, 3, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, 3, shard_hc, hs) + .transpose(0, 4) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, 3, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + current_omni_platform.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, 3, bs, hs) + + # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 3).contiguous() + + return output.reshape(bs, shard_seqlen, 3, hc, hs).contiguous() + else: + raise RuntimeError("scatter_idx must be 1 or 3 and gather_idx must be 1 or 3") + + +class SeqAllToAll5D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int = 3, + gather_idx: int = 1, + use_sync: bool = False, + ) -> Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + + return all_to_all_5D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + +class RingComm: + """Ring communication utility for Ring Attention P2P communication.""" + + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + 
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: torch.Tensor | None = None) -> torch.Tensor: + # Ensure to_send is contiguous for P2P + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + if recv_tensor is None: + # Create a contiguous buffer for receiving + res = torch.empty_like(to_send, memory_format=torch.contiguous_format) + # print(f"send_recv: empty_like {to_send.shape}") + else: + res = recv_tensor + if not res.is_contiguous(): + res = res.contiguous() + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..b722f61c07de9c489a1846b76893c6f7bcbd805b --- /dev/null +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -0,0 +1,938 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import pickle +from collections import namedtuple +from typing import Any + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup +from vllm.logger import init_logger + +from vllm_omni.diffusion import envs +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + +env_info = envs.PACKAGES_CHECKER.get_packages_info() + + +def _split_tensor_dict( + tensor_dict: dict[str, torch.Tensor | Any], prefix: str = "" +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, "Avoid having '%' in key as it is used as a separator for nested entries." + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
+ device = value.device.type + metadata_list.append((prefix + key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict(value, prefix + key + "%") + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = current_omni_platform.get_torch_device(local_rank) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, op=op, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> torch.Tensor | list[torch.Tensor]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty(input_size, dtype=input_.dtype, device=input_.device) + # All-gather. 
+ torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape( + [ + world_size, + ] + + input_size + ) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1).narrow(0, input_.numel() * i, input_.numel()).view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group) + return input_ + + def broadcast_object(self, obj: Any | None = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, "Invalid destination rank. 
Destination rank is the same as the current rank." + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group) + + assert rank_object == rank_size, "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, + src: int = 0, + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=metadata_group, async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, src=src, group=group, async_op=True) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict(self, src: int | None = None) -> dict[str, torch.Tensor | Any] | None: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), + ) + + def recv(self, size: torch.Size, dtype: torch.dtype, src: int | None = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + (self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class PipelineGroupCoordinator(GroupCoordinator): + """ + available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + difference between `local_rank` and `rank_in_group`: + if we have a group of size 4 across two nodes: + Process | Node | Rank | Local Rank | Rank in Group + 0 | 0 | 0 | 0 | 0 + 1 | 0 | 1 | 1 | 1 + 2 | 1 | 2 | 0 | 2 + 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + """ + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + self.cpu_groups = [] + self.device_groups = [] + if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1: + for ranks in group_ranks: + device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + # when pipeline parallelism is 2, we need to create two groups to avoid + # communication stall. + # *_group_0_1 represents the group for communication from device 0 to + # device 1. + # *_group_1_0 represents the group for communication from device 1 to + # device 0. + elif len(group_ranks[0]) == 2: + for ranks in group_ranks: + device_group_0_1 = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + device_group_1_0 = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo") + cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_groups = [device_group_0_1, device_group_1_0] + self.cpu_groups = [cpu_group_0_1, cpu_group_1_0] + self.device_group = device_group_0_1 + self.cpu_group = cpu_group_0_1 + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = current_omni_platform.get_torch_device(local_rank) + + self.recv_buffer_set: bool = False + self.recv_tasks_queue: list[tuple[str, int]] = [] + self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = [] + self.dtype: torch.dtype | None = None + self.num_pipefusion_patches: int | None = None + + self.recv_shape: dict[str, dict[int, torch.Size]] = {} + self.send_shape: dict[str, dict[int, torch.Size]] = {} + self.recv_buffer: dict[str, dict[int, torch.Size]] = {} + + self.skip_tensor_recv_buffer_set: bool = False + self.recv_skip_tasks_queue: list[int | tuple[str, int]] = [] + self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = [] + self.skip_tensor_recv_buffer: list[torch.Tensor] | torch.Tensor | None = None + self.skip_device_group = None + for ranks in group_ranks: + skip_device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) + if self.rank in ranks: + self.skip_device_group = skip_device_group + assert self.skip_device_group is not None + + def reset_buffer(self): + self.recv_tasks_queue = [] + self.receiving_tasks = [] + self.recv_shape = {} + self.send_shape = {} + self.recv_buffer = {} + + self.recv_skip_tasks_queue = [] + self.receiving_skip_tasks = [] + self.skip_tensor_recv_buffer = {} + + def set_config(self, dtype: torch.dtype): + self.dtype = dtype + + def set_recv_buffer( + self, + num_pipefusion_patches: int, + patches_shape_list: list[list[int]], + feature_map_shape: list[int], + dtype: torch.dtype, + ): + assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object" + assert isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1, ( + "num_pipefusion_patches must be greater than or equal to 1" + ) + self.dtype = dtype + self.num_pipefusion_patches = num_pipefusion_patches + self.recv_buffer = [torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list] + self.recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)) + self.recv_buffer_set = True + + def set_extra_tensors_recv_buffer( + self, + name: str, + shape: list[int], + num_buffers: int = 1, + dtype: torch.dtype = torch.float16, + ): + self.extra_tensors_recv_buffer[name] = [ + torch.zeros(*shape, dtype=dtype, device=self.device) for _ in range(num_buffers) + ] + + def _check_shape_and_buffer( + self, + tensor_send_to_next=None, + recv_prev=False, + name: str | None = None, + segment_idx: int = 0, + ): + send_flag = False + name = name or "latent" + if tensor_send_to_next is not None: + shape_list = self.send_shape.get(name, None) + if shape_list is None: + self.send_shape[name] = {segment_idx: tensor_send_to_next.shape} + send_flag = True + elif shape_list.get(segment_idx, None) is None: + self.send_shape[name][segment_idx] = tensor_send_to_next.shape + send_flag = True + + recv_flag = False + if recv_prev: + shape_list = self.recv_shape.get(name, None) + if shape_list is None: + recv_flag = True + elif shape_list.get(segment_idx, 
None) is None: + recv_flag = True + + recv_prev_shape = self._communicate_shapes( + tensor_send_to_next=tensor_send_to_next if send_flag else None, + recv_prev=recv_flag, + ) + + if recv_flag: + if self.recv_shape.get(name, None) is None: + self.recv_shape[name] = {segment_idx: recv_prev_shape} + else: + self.recv_shape[name][segment_idx] = recv_prev_shape + + if self.recv_buffer.get(name, None) is None: + self.recv_buffer[name] = { + segment_idx: torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype) + } + else: + if self.recv_buffer[name].get(segment_idx, None) is not None: + logger.warning(f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating...") + self.recv_buffer[name][segment_idx] = torch.zeros(recv_prev_shape, device=self.device, dtype=self.dtype) + + def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): + """Communicate tensor shapes between stages. Used to communicate + tensor shapes before the actual tensor communication happens. + + Args: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + """ + + ops = [] + if recv_prev: + recv_prev_dim_tensor = torch.empty((1), device=self.device, dtype=torch.int64) + recv_prev_dim_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_dim_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_dim_op) + + if tensor_send_to_next is not None: + send_next_dim_tensor = torch.tensor(tensor_send_to_next.dim(), device=self.device, dtype=torch.int64) + send_next_dim_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_dim_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_dim_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + # should take this out once the bug with batch_isend_irecv is resolved. 
+ current_omni_platform.synchronize() + + ops = [] + recv_prev_shape_tensor = None + if recv_prev: + recv_prev_shape_tensor = torch.empty( + torch.Size(recv_prev_dim_tensor), device=self.device, dtype=torch.int64 + ) + recv_prev_shape_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_shape_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_shape_op) + + if tensor_send_to_next is not None: + send_next_shape_tensor = torch.tensor(tensor_send_to_next.size(), device=self.device, dtype=torch.int64) + send_next_shape_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_shape_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_shape_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + current_omni_platform.synchronize() + + recv_prev_shape = [0, 0, 0] + if recv_prev_shape_tensor is not None: + recv_prev_shape = recv_prev_shape_tensor + return torch.Size(recv_prev_shape) + + def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) + self._pipeline_isend(tensor).wait() + + def pipeline_isend(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) + self._pipeline_isend(tensor) + + def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: + name = name or "latent" + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self._pipeline_irecv(self.recv_buffer[name][idx]).wait() + return self.recv_buffer[name][idx] + + def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"): + name = name or "latent" + self.recv_tasks_queue.append((name, idx)) + + def recv_next(self): + if len(self.recv_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_tasks_queue) > 0: + name, idx = self.recv_tasks_queue.pop(0) + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self.receiving_tasks.append((self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)) + + def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.Tensor: + assert len(self.receiving_tasks) > 0, "No tasks to receive, call add_pipeline_recv_task first" + receiving_task = self.receiving_tasks.pop(0) + receiving_task[0].wait() + assert receiving_task[1] == name and receiving_task[2] == idx, "Received tensor does not match the requested" + return self.recv_buffer[name][idx] + + def _pipeline_irecv(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, + src=self.prev_rank, + group=(self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), + ) + + def _pipeline_isend(self, tensor: torch.tensor): + return torch.distributed.isend( + tensor, + dst=self.next_rank, + group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), + ) + + def set_skip_tensor_recv_buffer( + self, + patches_shape_list: list[list[int]], + feature_map_shape: list[int], + ): + self.skip_tensor_recv_buffer = [ + torch.zeros(*shape, dtype=self.dtype, device=self.device) for shape in patches_shape_list + ] + self.skip_tensor_recv_buffer.append(torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)) 
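+ # Mark the pre-allocated skip-connection receive buffers as ready; pipeline_recv_skip()/recv_skip_next() reuse these tensors in place rather than allocating per call.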
+ self.skip_tensor_recv_buffer_set = True + + def pipeline_send_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor).wait() + + def pipeline_isend_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor) + + def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor: + self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait() + return self.skip_tensor_recv_buffer[idx] + + def add_pipeline_recv_skip_task(self, idx: int = -1): + self.recv_skip_tasks_queue.append(idx) + + def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor: + assert len(self.receiving_skip_tasks) > 0, "No tasks to receive, call add_pipeline_recv_skip_task first" + receiving_skip_task = self.receiving_skip_tasks.pop(0) + receiving_skip_task[0].wait() + assert receiving_skip_task[2] == idx, "Received tensor does not match the requested" + return self.skip_tensor_recv_buffer[idx] + + def recv_skip_next(self): + if len(self.recv_skip_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_skip_tasks_queue) > 0: + task = self.recv_skip_tasks_queue.pop(0) + idx = task + self.receiving_skip_tasks.append( + ( + self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]), + None, + idx, + ) + ) + + def _pipeline_irecv_skip(self, tensor: torch.tensor): + return torch.distributed.irecv(tensor, src=self.skip_rank, group=self.skip_device_group) + + def _pipeline_isend_skip(self, tensor: torch.tensor): + return torch.distributed.isend(tensor, dst=self.skip_rank, group=self.skip_device_group) + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + + ulysses_group = kwargs.get("ulysses_group", None) + ring_group = kwargs.get("ring_group", None) + if ulysses_group is None: + raise RuntimeError( + "Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" + ) + if ring_group is None: + raise RuntimeError( + "Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator" + ) + self.ulysses_group = ulysses_group + self.ring_group = ring_group + + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) diff --git a/vllm_omni/diffusion/distributed/parallel_state.py b/vllm_omni/diffusion/distributed/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7c2f85246616cfebdad68e2c65a658a53045f2 --- /dev/null +++ b/vllm_omni/diffusion/distributed/parallel_state.py @@ -0,0 +1,754 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 xDiT team. 
+# Adapted from +# https://github.com/xdit-project/xDiT/blob/main/xfuser/core/distributed/utils.py +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM-Omni distributed state. + +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model parallelism, + you can skip the model parallel initialization and destruction steps. +""" + +import torch +import torch.distributed +import vllm.distributed.parallel_state as vllm_parallel_state +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size +from vllm.logger import init_logger + +from vllm_omni.diffusion import envs +from vllm_omni.platforms import current_omni_platform + +from .group_coordinator import ( + GroupCoordinator, + PipelineGroupCoordinator, + SequenceParallelGroupCoordinator, +) + +env_info = envs.PACKAGES_CHECKER.get_packages_info() + +HAS_FLASH_ATTN = env_info["has_flash_attn"] + +logger = init_logger(__name__) + + +_WORLD: GroupCoordinator | None = None +# get _TP from vllm.distributed.parallel_state +_SP: SequenceParallelGroupCoordinator | None = None +_PP: PipelineGroupCoordinator | None = None +_CFG: GroupCoordinator | None = None +_DP: GroupCoordinator | None = None +_DIT: GroupCoordinator | None = None +_VAE: GroupCoordinator | None = None + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: list[int], mask: list[bool] +) -> list[list[int]]: + r"""Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (list[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (list[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1) + tp_rank \in [0, tp_size) + dp_rank \in [0, dp_size) + pp_rank \in [0, pp_size) + + If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. + For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the + dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) + The tp_rank and pp_rank will be combined to form the `dp_group_index`. 
+ dp_group_index = tp_rank + pp_rank * tp_size (2)
+
+ So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank is in
+ range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy
+ equation (1).
+
+ This function solves this math problem.
+
+ For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
+ and the mask = [False, True, False]. Then,
+ dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
+ dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
+ ...
+ dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2
+
+ dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
+ dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
+ ...
+ dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
+ """
+
+ def prefix_product(a: list[int], init=1) -> list[int]:
+ r = [init]
+ for v in a:
+ init = init * v
+ r.append(init)
+ return r
+
+ def inner_product(a: list[int], b: list[int]) -> int:
+ return sum([x * y for x, y in zip(a, b)])
+
+ def decompose(index, shape, stride=None):
+ """
+ This function solves the math problem below:
+ There is an equation:
+ index = sum(idx[i] * stride[i])
+ Given the values of index and stride, return idx.
+ This function is used to get the tp/dp/pp rank
+ from group_index and rank_in_group.
+ """
+ if stride is None:
+ stride = prefix_product(shape)
+ idx = [(index // d) % s for s, d in zip(shape, stride)]
+ # stride is a prefix_product result. And the value of stride[-1]
+ # is not used.
+ assert sum([x * y for x, y in zip(idx, stride[:-1])]) == index, (
+ f"idx {index} with shape {shape} does not match the decomposed idx {idx}"
+ )
+ return idx
+
+ masked_shape = [s for s, m in zip(parallel_size, mask) if m]
+ unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
+
+ global_stride = prefix_product(parallel_size)
+ masked_stride = [d for d, m in zip(global_stride, mask) if m]
+ unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]
+
+ group_size = prefix_product(masked_shape)[-1]
+ num_of_group = world_size // group_size
+
+ ranks = []
+ for group_index in range(num_of_group):
+ # get indices from the unmasked shape for group_index.
+ decomposed_group_idx = decompose(group_index, unmasked_shape)
+ rank = []
+ for rank_in_group in range(group_size):
+ # get indices from the masked shape for rank_in_group.
+ decomposed_rank_idx = decompose(rank_in_group, masked_shape)
+ rank.append(
+ inner_product(decomposed_rank_idx, masked_stride) + inner_product(decomposed_group_idx, unmasked_stride)
+ )
+ ranks.append(rank)
+ return ranks
+
+
+class RankGenerator:
+ def __init__(
+ self,
+ tp: int,
+ sp: int,
+ pp: int,
+ cfg: int,
+ dp: int,
+ order: str,
+ rank_offset: int = 0,
+ ) -> None:
+ self.tp = tp
+ self.sp = sp
+ self.pp = pp
+ self.cfg = cfg
+ self.dp = dp
+ self.rank_offset = rank_offset
+ self.world_size = tp * sp * pp * cfg * dp
+
+ self.name_to_size = {
+ "tp": self.tp,
+ "sp": self.sp,
+ "pp": self.pp,
+ "cfg": self.cfg,
+ "dp": self.dp,
+ }
+ order = order.lower()
+
+ for name in self.name_to_size.keys():
+ if name not in order and self.name_to_size[name] != 1:
+ raise RuntimeError(
+ f"The size of ({name}) is ({self.name_to_size[name]}), "
+ f"but it is not included in the order ({order})."
+ )
+ elif name not in order:
+ order = order + "-" + name
+
+ self.order = order
+ self.ordered_size = []
+
+ for token in order.split("-"):
+ self.ordered_size.append(self.name_to_size[token])
+
+ def get_mask(self, order: str, token: str):
+ ordered_token = order.split("-")
+ token = token.split("-")
+ mask = [False] * len(ordered_token)
+ for t in token:
+ mask[ordered_token.index(t)] = True
+ return mask
+
+ def get_ranks(self, token):
+ """Get rank groups for the given token.
+
+ Arguments:
+ token (str):
+ Specify the type of ranks to get. If we want
+ to obtain multiple parallel types, we can use a hyphen
+ '-' to separate them. For example, if we want to obtain
+ the TP_DP group, the token should be 'tp-dp'.
+ """
+ mask = self.get_mask(self.order, token)
+ ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask)
+ if self.rank_offset > 0:
+ for rank_group in ranks:
+ for i in range(len(rank_group)):
+ rank_group[i] += self.rank_offset
+ return ranks
+
+
+# * QUERY
+def get_world_group() -> GroupCoordinator:
+ assert _WORLD is not None, "world group is not initialized"
+ return _WORLD
+
+
+# SP
+def get_sp_group() -> SequenceParallelGroupCoordinator:
+ assert _SP is not None, "sequence parallel group is not initialized"
+ return _SP
+
+
+def get_sequence_parallel_world_size():
+ """Return world size for the sequence parallel group."""
+ return get_sp_group().world_size
+
+
+def get_sequence_parallel_rank():
+ """Return my rank for the sequence parallel group."""
+ return get_sp_group().rank_in_group
+
+
+def get_ulysses_parallel_world_size():
+ return get_sp_group().ulysses_world_size
+
+
+def get_ulysses_parallel_rank():
+ return get_sp_group().ulysses_rank
+
+
+def get_ring_parallel_world_size():
+ return get_sp_group().ring_world_size
+
+
+def get_ring_parallel_rank():
+ return get_sp_group().ring_rank
+
+
+# PP
+def get_pp_group() -> PipelineGroupCoordinator:
+ assert _PP is not None, "pipeline model parallel group is not initialized"
+ return _PP
+
+
+def get_pipeline_parallel_world_size():
+ """Return world size for the pipeline model parallel group."""
+ return get_pp_group().world_size
+
+
+def get_pipeline_parallel_rank():
+ """Return my rank for the pipeline model parallel group."""
+ return get_pp_group().rank_in_group
+
+
+def is_pipeline_first_stage():
+ """Return True if in the first pipeline model parallel stage, False otherwise."""
+ return get_pipeline_parallel_rank() == 0
+
+
+def is_pipeline_last_stage():
+ """Return True if in the last pipeline model parallel stage, False otherwise."""
+ return get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
+
+
+# CFG
+def get_cfg_group() -> GroupCoordinator:
+ assert _CFG is not None, "classifier_free_guidance parallel group is not initialized"
+ return _CFG
+
+
+def get_classifier_free_guidance_world_size():
+ """Return world size for the classifier_free_guidance parallel group."""
+ return get_cfg_group().world_size
+
+
+def get_classifier_free_guidance_rank():
+ """Return my rank for the classifier_free_guidance parallel group."""
+ return get_cfg_group().rank_in_group
+
+
+# DP
+def get_dp_group() -> GroupCoordinator:
+ assert _DP is not None, "data parallel group is not initialized"
+ return _DP
+
+
+def get_data_parallel_world_size():
+ """Return world size for the data parallel group."""
+ return get_dp_group().world_size
+
+
+def get_data_parallel_rank():
+ """Return my rank for the data parallel group."""
+ return get_dp_group().rank_in_group
+
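+# Illustrative sketch (comment only, not executed): how the RankGenerator
+# layout maps onto the query helpers above, assuming a hypothetical 8-GPU run
+# with tp=1, sp=2, pp=2, cfg=1, dp=2 and the default order "tp-sp-pp-cfg-dp":
+#
+#   gen = RankGenerator(1, 2, 2, 1, 2, "tp-sp-pp-cfg-dp")
+#   gen.get_ranks("sp")  -> [[0, 1], [2, 3], [4, 5], [6, 7]]
+#   gen.get_ranks("pp")  -> [[0, 2], [1, 3], [4, 6], [5, 7]]
+#   gen.get_ranks("dp")  -> [[0, 4], [1, 5], [2, 6], [3, 7]]
+#
+# Once initialize_model_parallel() has built these groups, global rank 5 would
+# report get_sequence_parallel_rank() == 1, get_pipeline_parallel_rank() == 0
+# and get_data_parallel_rank() == 1.
+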
+ +def is_dp_last_group(): + """Return True if in the last data parallel group, False otherwise.""" + return ( + get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1) + and get_classifier_free_guidance_rank() == (get_classifier_free_guidance_world_size() - 1) + and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1) + ) + + +def get_dit_world_size(): + """Return world size for the DiT model (excluding VAE).""" + return ( + get_data_parallel_world_size() + * get_classifier_free_guidance_world_size() + * get_sequence_parallel_world_size() + * get_pipeline_parallel_world_size() + * get_tensor_model_parallel_world_size() + ) + + +# Add VAE getter functions +def get_vae_parallel_group() -> GroupCoordinator: + assert _VAE is not None, "VAE parallel group is not initialized" + return _VAE + + +def get_vae_parallel_world_size(): + """Return world size for the VAE parallel group.""" + return get_vae_parallel_group().world_size + + +def get_vae_parallel_rank(): + """Return my rank for the VAE parallel group.""" + return get_vae_parallel_group().rank_in_group + + +# * SET + + +def init_world_group(ranks: list[int], local_rank: int, backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str | None = None, +): + if backend is None: + backend = current_omni_platform.dist_backend + logger.debug( + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing distributed environment" + ) + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + device_id = torch.distributed.get_rank() % current_omni_platform.get_device_count() + current_omni_platform.set_device(current_omni_platform.get_torch_device(device_id)) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size" + ) + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _DP is not None + and _CFG is not None + and _SP is not None + and _PP is not None + and vllm_parallel_state._TP is not None + ) + + +def init_model_parallel_group( + group_ranks: list[list[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "data", + "pipeline", + "tensor", + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == 
"pipeline": + return PipelineGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + elif parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_dit_group( + dit_parallel_size: int, + backend: str, +): + global _DIT + _DIT = torch.distributed.new_group(ranks=list(range(dit_parallel_size)), backend=backend) + + +def get_dit_group(): + assert _DIT is not None, "DIT group is not initialized" + return _DIT + + +def init_vae_group( + dit_parallel_size: int, + vae_parallel_size: int, + backend: str, +): + # Initialize VAE group first + global _VAE + assert _VAE is None, "VAE parallel group is already initialized" + vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size)) + _VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend) + + +# adapted from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/globals.py +def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True): + """ + sp_ulysses_degree x sp_ring_degree = seq_parallel_size + (ulysses_degree, dp_size) + """ + sp_size = sp_ring_degree * sp_ulysses_degree + dp_size = world_size // sp_size + + assert world_size % sp_size == 0, f"world_size {world_size} % sp_size {sp_ulysses_degree} == 0" + + num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree + num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree + + if use_ulysses_low: + for dp_rank in range(dp_size): + offset = dp_rank * sp_size + for i in range(num_ulysses_pgs): + ulysses_ranks = list( + range( + i * sp_ulysses_degree + offset, + (i + 1) * sp_ulysses_degree + offset, + ) + ) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + for i in range(num_ring_pgs): + ring_ranks = list(range(i + offset, sp_size + offset, num_ring_pgs)) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + + else: + for dp_rank in range(dp_size): + offset = dp_rank * sp_size + for i in range(num_ring_pgs): + ring_ranks = list(range(i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset)) + group = torch.distributed.new_group(ring_ranks) + if rank in ring_ranks: + ring_pg = group + + for i in range(num_ulysses_pgs): + ulysses_ranks = list(range(i + offset, sp_size + offset, num_ulysses_pgs)) + group = torch.distributed.new_group(ulysses_ranks) + if rank in ulysses_ranks: + ulyssess_pg = group + + return ulyssess_pg, ring_pg + + +def initialize_model_parallel( + data_parallel_size: int = 1, + cfg_parallel_size: int = 1, + sequence_parallel_size: int | None = None, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + vae_parallel_size: int = 0, + backend: str | None = None, +) -> None: + if backend is None: + backend = current_omni_platform.dist_backend + """ + Initialize model parallel groups. + + Arguments: + data_parallel_size: number of data parallelism groups. + cfg_parallel_size: number of GPUs used for Classifier Free Guidance (CFG) parallelism. + sequence_parallel_size: number of GPUs used for sequence parallelism. 
+ sequence_parallel_size = ulysses_degree * ring_degree + ulysses_degree: number of GPUs used for ulysses sequence parallelism. + ring_degree: number of GPUs used for ring sequence parallelism. + tensor_parallel_size: number of GPUs used for tensor parallelism. + pipeline_parallel_size: number of GPUs used for pipeline parallelism. + backend: distributed backend of pytorch collective comm. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize + split batch caused by CFG, and 2 GPUs to parallelize sequence. + + dp_size (2) * cfg_size (2) * sp_size (2) * pp_size (2) = 16. + + The present function will create 8 data-parallel groups, + 8 CFG group, 8 pipeline-parallel group, and + 8 sequence-parallel groups: + 8 data-parallel groups: + [g0, g8], [g1, g9], [g2, g10], [g3, g11], + [g4, g12], [g5, g13], [g6, g14], [g7, g15] + 8 CFG-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7], + [g8, g12], [g9, g13], [g10, g14], [g11, g15] + 8 sequence-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], + [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 8 pipeline-parallel groups: + [g0, g2], [g4, g6], [g8, g10], [g12, g14], + [g1, g3], [g5, g7], [g9, g11], [g13, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if sequence_parallel_size is None: + sequence_parallel_size = ring_degree * ulysses_degree + logger.info( + f"sequence_parallel_size is not provided, using ring_degree * ulysses_degree = {sequence_parallel_size}" + ) + + if sequence_parallel_size != ring_degree * ulysses_degree: + raise ValueError( + "sequence_parallel_size is not equal to ring_degree * ulysses_degree," + f" but got {sequence_parallel_size} != {ring_degree} * {ulysses_degree}" + ) + + # FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch, + # the pipefusion is not ready for npu yet + if current_omni_platform.is_npu(): + assert pipeline_parallel_size == 1, "Current pipefusion is not ready for NPU" + + dit_parallel_size = ( + data_parallel_size * cfg_parallel_size * sequence_parallel_size * pipeline_parallel_size * tensor_parallel_size + ) + + if world_size < dit_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than " + f"tensor_parallel_size ({tensor_parallel_size}) x " + f"pipeline_parallel_size ({pipeline_parallel_size}) x" + f"sequence_parallel_size ({sequence_parallel_size}) x" + f"cfg_parallel_size " + f"({cfg_parallel_size}) x" + f"data_parallel_size ({data_parallel_size})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_size, + sequence_parallel_size, + pipeline_parallel_size, + cfg_parallel_size, + data_parallel_size, + "tp-sp-pp-cfg-dp", + ) + global _DP + assert _DP is None, "data parallel group is already initialized" + _DP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("dp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="data", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = 
init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + _PP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("pp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="pipeline", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + ulysses_pg, ring_pg = set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=dit_parallel_size, + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=ulysses_pg, + ring_group=ring_pg, + ) + + assert vllm_parallel_state._TP is None, "Tensor parallel group is already initialized" + vllm_parallel_state._TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + if vae_parallel_size > 0: + init_vae_group(dit_parallel_size, vae_parallel_size, backend) + init_dit_group(dit_parallel_size, backend) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _DP + if _DP: + _DP.destroy() + _DP = None + + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + if vllm_parallel_state._TP: + vllm_parallel_state._TP.destroy() + vllm_parallel_state._TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + global _VAE + if _VAE: + _VAE.destroy() + _VAE = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def destroy_distributed_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/vllm_omni/diffusion/distributed/sp_plan.py b/vllm_omni/diffusion/distributed/sp_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7b1d66b0012c796db56147612ab001f94528c0 --- /dev/null +++ b/vllm_omni/diffusion/distributed/sp_plan.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM and The HuggingFace Team +# Type definitions in this module are adapted from HuggingFace diffusers library: +# diffusers/src/diffusers/models/_modeling_parallel.py +"""Sequence Parallelism configuration and plan type definitions. + +This module provides: +1. SequenceParallelConfig: Configuration for SP (ulysses_degree, ring_degree) +2. SequenceParallelInput/Output: Type definitions for _sp_plan declarations +3. Validation utilities for _sp_plan + +A _sp_plan is a dictionary that specifies how to shard/gather tensors at +different points in a model's forward pass. This allows automatic handling +of sequence parallelism without modifying the model's forward() method. + +NOTE: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) +in diffusers. We use "Sequence Parallelism" to align with vLLM-Omni terminology. 
+ +Example: + class MyTransformer(nn.Module): + _sp_plan = { + # Split inputs before model forward + "": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + # Split RoPE embeddings after pos_embed layer + "pos_embed": { + 0: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + # Gather output after proj_out layer + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch + import torch.nn as nn + + +# ============================================================================= +# Sequence Parallel Configuration +# ============================================================================= + + +@dataclass +class SequenceParallelConfig: + """Configuration for Sequence Parallelism using vLLM-Omni's parallel state. + + This class provides a unified interface for SP configuration that integrates + with vLLM-Omni's existing SequenceParallelGroupCoordinator. Unlike diffusers' + DeviceMesh-based approach (ContextParallelConfig), this uses the existing + parallel state management. + + Note: This corresponds to `ContextParallelConfig` in diffusers library. + + Args: + ulysses_degree: Number of devices for Ulysses (All-to-All) attention. + Sequence is split across devices, with Q/K/V redistributed via + All-to-All communication. Best for moderate sequences with good + interconnect bandwidth. + ring_degree: Number of devices for Ring attention. Sequence is split + across devices, with K/V passed in a ring topology. Best for long + sequences with limited memory/bandwidth. + convert_to_fp32: Whether to convert output and LSE to float32 for + numerical stability in ring attention. + + Note: + ulysses_degree * ring_degree = sequence_parallel_size + vLLM-Omni supports hybrid Ulysses-Ring attention (both > 1). + """ + + ulysses_degree: int = 1 + ring_degree: int = 1 + convert_to_fp32: bool = True + + # Internal state - populated by setup() + _rank: int | None = None + _world_size: int | None = None + _device: torch.device | None = None + + def __post_init__(self) -> None: + if self.ulysses_degree < 1 or self.ring_degree < 1: + raise ValueError("`ulysses_degree` and `ring_degree` must be >= 1.") + + if self.ulysses_degree == 1 and self.ring_degree == 1: + raise ValueError( + "At least one of `ulysses_degree` or `ring_degree` must be > 1 to use sequence parallelism." + ) + + @property + def sequence_parallel_size(self) -> int: + """Total sequence parallel world size.""" + return self.ulysses_degree * self.ring_degree + + def get_world_size(self) -> int: + """Get the sequence parallel world size from parallel state. + + Returns: + The world size for sequence parallelism. + + Raises: + RuntimeError: If parallel state is not initialized. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size + + return get_sequence_parallel_world_size() + + def get_rank(self) -> int: + """Get the current rank in the sequence parallel group. + + Returns: + The rank within the sequence parallel group. + + Raises: + RuntimeError: If parallel state is not initialized. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_rank + + return get_sequence_parallel_rank() + + def get_ulysses_world_size(self) -> int: + """Get the Ulysses parallel world size. 
+ + Returns: + The world size for Ulysses (All-to-All) parallelism. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_ulysses_parallel_world_size + + return get_ulysses_parallel_world_size() + + def get_ulysses_rank(self) -> int: + """Get the current rank in the Ulysses parallel group. + + Returns: + The rank within the Ulysses parallel group. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_ulysses_parallel_rank + + return get_ulysses_parallel_rank() + + def get_ring_world_size(self) -> int: + """Get the Ring parallel world size. + + Returns: + The world size for Ring attention parallelism. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_ring_parallel_world_size + + return get_ring_parallel_world_size() + + def get_ring_rank(self) -> int: + """Get the current rank in the Ring parallel group. + + Returns: + The rank within the Ring parallel group. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_ring_parallel_rank + + return get_ring_parallel_rank() + + def setup(self, rank: int, world_size: int, device: torch.device) -> None: + """Initialize the config with runtime parallel state. + + This is called automatically when sequence parallelism is enabled. + + Args: + rank: The global rank of this process. + world_size: Total world size. + device: The device for this rank. + """ + self._rank = rank + self._world_size = world_size + self._device = device + + expected_sp_size = self.ulysses_degree * self.ring_degree + actual_sp_size = self.get_world_size() + + if expected_sp_size != actual_sp_size: + raise ValueError( + f"Configuration mismatch: ulysses_degree ({self.ulysses_degree}) * " + f"ring_degree ({self.ring_degree}) = {expected_sp_size}, but " + f"actual sequence parallel world size is {actual_sp_size}." + ) + + def is_initialized(self) -> bool: + """Check if the config has been initialized with runtime state. + + Returns: + True if setup() has been called, False otherwise. + """ + return self._rank is not None + + +# ============================================================================= +# Sequence Parallel Plan Type Definitions +# ============================================================================= + + +@dataclass(frozen=True) +class SequenceParallelInput: + """Configuration for splitting an input tensor across sequence parallel ranks. + + This specifies how to shard a tensor in the pre-forward or post-forward hook + of a layer. The tensor will be split along the specified dimension. + + Note: This corresponds to `ContextParallelInput` in diffusers library. + + Args: + split_dim: The dimension along which to split the tensor. + expected_dims: Expected number of dimensions. If provided, validates that + the tensor has this many dimensions before splitting. If the tensor + has a different number of dimensions, splitting is skipped with a warning. + split_output: If True, split the output of the layer instead of the input. + This is useful for layers whose outputs should be split after preprocessing + (e.g., RoPE embeddings). + auto_pad: If True, automatically pad the tensor if its size along split_dim + is not divisible by world_size. Creates an attention mask to indicate + valid vs padding positions. The mask is stored in ForwardContext. + Note: Ring attention does not support attention mask, so auto_pad + should only be used with Ulysses SP. 
+ + Example: + # Split hidden_states along sequence dimension (dim 1) + SequenceParallelInput(split_dim=1, expected_dims=3) + + # Split RoPE output along sequence dimension (dim 0) + SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True) + + # Split with auto-padding for variable-length sequences + SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True) + """ + + split_dim: int + expected_dims: int | None = None + split_output: bool = False + auto_pad: bool = False + + def __repr__(self) -> str: + return ( + f"SequenceParallelInput(split_dim={self.split_dim}, " + f"expected_dims={self.expected_dims}, split_output={self.split_output}, " + f"auto_pad={self.auto_pad})" + ) + + +@dataclass(frozen=True) +class SequenceParallelOutput: + """Configuration for gathering an output tensor across sequence parallel ranks. + + This specifies how to gather a tensor in the post-forward hook of a layer. + The tensor will be gathered along the specified dimension from all ranks. + + Note: This corresponds to `ContextParallelOutput` in diffusers library. + + Args: + gather_dim: The dimension along which to gather the tensor. + expected_dims: Expected number of dimensions. If provided, validates that + the tensor has this many dimensions before gathering. + + Example: + # Gather output along sequence dimension (dim 1) + SequenceParallelOutput(gather_dim=1, expected_dims=3) + """ + + gather_dim: int + expected_dims: int | None = None + + def __repr__(self) -> str: + return f"SequenceParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" + + +@dataclass(frozen=True) +class SequenceParallelPartialInput: + """Configuration for partially splitting a tensor (e.g., split image part, keep text part). + + This is designed for models like LongCat/Qwen where RoPE embeddings need special handling: + - Text portion: kept full across all ranks (for joint attention) + - Image portion: split across ranks + + The tensor is assumed to be concatenated as [text_part, image_part] along split_dim. + + Note: This is an extension beyond diffusers' standard ContextParallelInput, + designed for vLLM-Omni's dual-stream attention models. + + Args: + split_dim: The dimension along which to split the image portion. + text_len_source: How to determine text length: + - str: Name of a forward parameter that contains text length + - int: Fixed text length value + expected_dims: Expected number of dimensions for validation. + split_output: If True, split the output instead of input. 
+ + Example: + # Split RoPE: text portion (from txt_ids.shape[0]) kept full, image portion split + SequenceParallelPartialInput( + split_dim=0, + text_len_source="txt_ids", # Get text length from txt_ids.shape[0] + expected_dims=2, + split_output=True, + ) + + # Or with fixed text length + SequenceParallelPartialInput( + split_dim=0, + text_len_source=512, # Fixed text length + expected_dims=2, + split_output=True, + ) + """ + + split_dim: int + text_len_source: str | int + expected_dims: int | None = None + split_output: bool = False + + def __repr__(self) -> str: + return ( + f"SequenceParallelPartialInput(split_dim={self.split_dim}, " + f"text_len_source={self.text_len_source!r}, expected_dims={self.expected_dims}, " + f"split_output={self.split_output})" + ) + + +# ============================================================================= +# Type Aliases for _sp_plan Structure +# ============================================================================= + +# Any input config type +AnySequenceParallelInput = SequenceParallelInput | SequenceParallelPartialInput + +# Input specification: maps parameter names (str) or output indices (int) to split config +SequenceParallelInputType = dict[ + str | int, + AnySequenceParallelInput | list[AnySequenceParallelInput] | tuple[AnySequenceParallelInput, ...], +] + +# Output specification: single or multiple gather configs +SequenceParallelOutputType = SequenceParallelOutput | list[SequenceParallelOutput] | tuple[SequenceParallelOutput, ...] + +# Full model plan: maps module names to input/output specifications +# - Key "" refers to the model itself (root level) +# - Key "module_name" refers to a submodule +# - Key "module_name.*" refers to all children of a ModuleList +# +# Example of a complete _sp_plan: +# +# _sp_plan = { +# # Root level: split model inputs before any submodule +# "": { +# "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), +# }, +# # Submodule: split outputs of pos_embed (RoPE) layer +# "pos_embed": { +# 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # cos +# 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # sin +# }, +# # Submodule: gather outputs of proj_out layer +# "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), +# } +# +SequenceParallelModelPlan = dict[str, SequenceParallelInputType | SequenceParallelOutputType] + + +# ============================================================================= +# Validation Utilities +# ============================================================================= + + +def _is_valid_input_config(value: object) -> bool: + """Check if a value is a valid input configuration type.""" + return isinstance(value, (SequenceParallelInput, SequenceParallelPartialInput)) + + +def _is_valid_input_config_list(value: object) -> bool: + """Check if a value is a list/tuple of valid input configurations.""" + if not isinstance(value, (list, tuple)): + return False + return all(_is_valid_input_config(x) for x in value) + + +def validate_sp_plan(plan: SequenceParallelModelPlan) -> None: + """Validate a _sp_plan dictionary for correctness. + + Args: + plan: The _sp_plan dictionary to validate. + + Raises: + ValueError: If the plan is invalid. 
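+
+ Example:
+ A minimal, illustrative plan that this validator accepts (the module
+ names and tensor names are placeholders, not a required schema):
+
+ validate_sp_plan({
+ "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ })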
+ """ + if not isinstance(plan, dict): + raise ValueError(f"_sp_plan must be a dict, got {type(plan).__name__}") + + for module_id, module_plan in plan.items(): + if not isinstance(module_id, str): + raise ValueError(f"_sp_plan keys must be strings, got {type(module_id).__name__}") + + # Check if it's an output specification (SequenceParallelOutput or list/tuple thereof) + if isinstance(module_plan, SequenceParallelOutput): + continue + if isinstance(module_plan, (list, tuple)): + if all(isinstance(x, SequenceParallelOutput) for x in module_plan): + continue + if _is_valid_input_config_list(module_plan): + # List of inputs for a specific parameter (when output is tuple) + continue + + # Otherwise, should be an input specification dict + if isinstance(module_plan, dict): + for key, value in module_plan.items(): + if not isinstance(key, (str, int)): + raise ValueError( + f"Input spec keys must be str or int, got {type(key).__name__} for module '{module_id}'" + ) + if isinstance(key, int) and not _is_valid_input_config(value): + raise ValueError( + f"Integer keys (output indices) must map to SequenceParallelInput/PartialInput, " + f"got {type(value).__name__} for module '{module_id}'[{key}]" + ) + if _is_valid_input_config(value): + if isinstance(key, int) and not value.split_output: + raise ValueError( + f"Integer keys (output indices) require split_output=True, " + f"got split_output=False for module '{module_id}'[{key}]" + ) + elif _is_valid_input_config_list(value): + pass # Valid list of input configs + else: + raise ValueError( + f"Input spec values must be SequenceParallelInput/PartialInput or list thereof, " + f"got {type(value).__name__} for module '{module_id}'['{key}']" + ) + else: + raise ValueError( + f"_sp_plan values must be dict (input spec) or SequenceParallelOutput, " + f"got {type(module_plan).__name__} for module '{module_id}'" + ) + + +def get_sp_plan_from_model(model: nn.Module) -> SequenceParallelModelPlan | None: + """Get the _sp_plan from a model if it exists. + + Args: + model: The model to get the plan from. + + Returns: + The _sp_plan dictionary, or None if not defined. + """ + plan = getattr(model, "_sp_plan", None) + if plan is not None: + validate_sp_plan(plan) + return plan diff --git a/vllm_omni/diffusion/distributed/sp_sharding.py b/vllm_omni/diffusion/distributed/sp_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..12520183c2eebfba9e9580005fc60f55f40bf545 --- /dev/null +++ b/vllm_omni/diffusion/distributed/sp_sharding.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and The HuggingFace Team +"""Sequence Parallelism sharding utilities. + +This module provides low-level sharding and gathering functions for +Sequence Parallelism. These can be used directly in model forward methods +for semi-intrusive SP support, or internally by the SP hooks. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass, field + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.distributed.parallel_state import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) + +logger = init_logger(__name__) + + +def sp_shard( + tensor: torch.Tensor, + dim: int, + validate: bool = True, +) -> torch.Tensor: + """Shard a tensor along the specified dimension for sequence parallelism. 
+ + The tensor is split into world_size chunks along dim, and this rank + receives its corresponding chunk. + + Args: + tensor: The tensor to shard. + dim: The dimension along which to split. + validate: If True, validate that the tensor size is divisible by world_size. + + Returns: + The shard for this rank. + + Raises: + ValueError: If validate=True and tensor size is not divisible by world_size. + + Example: + # In model forward: + hidden_states = sp_shard(hidden_states, dim=1) + """ + world_size = get_sequence_parallel_world_size() + + if world_size == 1: + return tensor + + rank = get_sequence_parallel_rank() + size = tensor.size(dim) + + if validate and size % world_size != 0: + raise ValueError( + f"Tensor size along dim {dim} ({size}) must be divisible by " + f"world_size ({world_size}) for sequence parallel sharding." + ) + + return tensor.chunk(world_size, dim=dim)[rank] + + +def sp_gather( + tensor: torch.Tensor, + dim: int, + validate: bool = True, +) -> torch.Tensor: + """Gather a tensor along the specified dimension from all sequence parallel ranks. + + The sharded tensors from all ranks are concatenated along dim. + + Args: + tensor: The local shard to gather. + dim: The dimension along which to gather. + validate: If True, validate tensor consistency (currently unused). + + Returns: + The full tensor gathered from all ranks. + + Example: + # At end of model forward: + output = sp_gather(output, dim=1) + """ + world_size = get_sequence_parallel_world_size() + + if world_size == 1: + return tensor + + sp_group = get_sp_group() + return sp_group.all_gather(tensor, dim=dim) + + +def sp_shard_with_padding( + tensor: torch.Tensor, + dim: int, + pad_value: float = 0.0, +) -> tuple[torch.Tensor, int]: + """Shard a tensor with automatic padding if not divisible by world_size. + + This is useful for variable-length sequences where padding may be needed. + + Args: + tensor: The tensor to shard. + dim: The dimension along which to split. + pad_value: Value to use for padding. + + Returns: + Tuple of (sharded_tensor, padding_size). The padding_size indicates + how much padding was added to the original tensor before sharding. + + Example: + sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1) + # ... process ... + output = sp_gather(output, dim=1) + if pad_size > 0: + output = output[..., :-pad_size] # Remove padding + """ + world_size = get_sequence_parallel_world_size() + + if world_size == 1: + return tensor, 0 + + size = tensor.size(dim) + remainder = size % world_size + + if remainder == 0: + return sp_shard(tensor, dim, validate=False), 0 + + # Pad to make divisible + pad_size = world_size - remainder + pad_shape = list(tensor.shape) + pad_shape[dim] = pad_size + padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, padding], dim=dim) + + return sp_shard(tensor, dim, validate=False), pad_size + + +# NOTE: This class is a vLLM-Omni extension for +# debugging intrusive SP implementations. +# Purpose: +# - Help developers detect bugs when implementing intrusive SP +# - Verify that every sharded tensor is properly gathered +# - Warn about common mistakes (double shard, gather without shard) +# +# When to use: +# - During development/debugging of intrusive SP code +# - In tests to verify shard/gather correctness +@dataclass +class ShardingValidator: + """Validator for tracking and verifying sharding operations. 
+ + This class helps ensure that sharding and gathering operations are + correctly paired in model forward passes. It tracks which tensors + have been sharded and verifies that they are properly gathered. + + Usage: + validator = ShardingValidator() + with validator.track(): + hidden_states = validator.shard(hidden_states, "hidden_states", dim=1) + # ... model computation ... + output = validator.gather(output, "hidden_states", dim=1) + validator.validate() # Raises if any shard was not gathered + + Attributes: + _sharded: Set of tensor names that have been sharded. + _gathered: Set of tensor names that have been gathered. + _enabled: Whether tracking is currently enabled. + """ + + _sharded: set[str] = field(default_factory=set) + _gathered: set[str] = field(default_factory=set) + _enabled: bool = False + + def reset(self) -> None: + """Reset the validator state for a new forward pass.""" + self._sharded.clear() + self._gathered.clear() + + @contextmanager + def track(self): + """Context manager to enable tracking for a forward pass.""" + self._enabled = True + self.reset() + try: + yield + finally: + self._enabled = False + + def shard( + self, + tensor: torch.Tensor, + name: str, + dim: int, + validate_divisible: bool = True, + ) -> torch.Tensor: + """Shard a tensor and track the operation. + + Args: + tensor: The tensor to shard. + name: A name to identify this tensor for validation. + dim: The dimension along which to split. + validate_divisible: If True, validate divisibility. + + Returns: + The sharded tensor. + """ + if self._enabled: + if name in self._sharded: + logger.warning(f"Tensor '{name}' sharded multiple times") + self._sharded.add(name) + + return sp_shard(tensor, dim, validate=validate_divisible) + + def gather( + self, + tensor: torch.Tensor, + name: str, + dim: int, + ) -> torch.Tensor: + """Gather a tensor and track the operation. + + Args: + tensor: The local shard to gather. + name: The name used when sharding (for validation). + dim: The dimension along which to gather. + + Returns: + The gathered tensor. + """ + if self._enabled: + if name not in self._sharded: + logger.warning(f"Tensor '{name}' gathered without being sharded") + self._gathered.add(name) + + return sp_gather(tensor, dim) + + def validate(self) -> None: + """Validate that all sharded tensors were gathered. + + Raises: + ValueError: If any sharded tensor was not gathered. + """ + unmatched = self._sharded - self._gathered + if unmatched: + raise ValueError( + f"The following tensors were sharded but not gathered: {unmatched}. " + f"This may indicate a bug in the model's SP implementation." + ) + + +# Global validator instance for convenience +_global_validator = ShardingValidator() + + +def get_sharding_validator() -> ShardingValidator: + """Get the global sharding validator instance. + + Returns: + The global ShardingValidator. 
+ """ + return _global_validator diff --git a/vllm_omni/diffusion/distributed/utils.py b/vllm_omni/diffusion/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92d00ef4b4d25a6f3ec5e32c02439599b2ed8046 --- /dev/null +++ b/vllm_omni/diffusion/distributed/utils.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +import torch + +from vllm_omni.platforms import current_omni_platform + + +def get_local_device() -> torch.device: + """Return the torch device for the current rank based on detected device type.""" + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + return current_omni_platform.get_torch_device(local_rank) diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py new file mode 100644 index 0000000000000000000000000000000000000000..dd566b0595e76077ddc3ca6e264fed8ba8bc9f25 --- /dev/null +++ b/vllm_omni/diffusion/envs.py @@ -0,0 +1,113 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/xdit-project/xDiT/blob/main/xfuser/envs.py +import os +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +from vllm_omni.platforms import current_omni_platform + +if TYPE_CHECKING: + MASTER_ADDR: str = "" + MASTER_PORT: int | None = None + CUDA_HOME: str | None = None + LOCAL_RANK: int = 0 + +environment_variables: dict[str, Callable[[], Any]] = { + # ================== Runtime Env Vars ================== + # used in distributed environment to determine the master address + "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""), + # used in distributed environment to manually set the communication port + "MASTER_PORT": lambda: (int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None), + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. 
+ "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), +} + +logger = init_logger(__name__) + + +class PackagesEnvChecker: + """Singleton class for checking package availability.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.initialize() + return cls._instance + + def initialize(self): + packages_info = {} + packages_info["has_flash_attn"] = self._check_flash_attn(packages_info) + self.packages_info = packages_info + + def _check_flash_attn(self, packages_info) -> bool: + """Check if flash attention is available and compatible.""" + platform = current_omni_platform + + # Flash attention requires CUDA-like platforms (CUDA or ROCm) + if not platform.is_cuda_alike(): + return False + + # Check if devices are available + if platform.get_device_count() == 0: + return False + + try: + gpu_name = platform.get_device_name() + # Turing/Tesla/T4 GPUs don't support flash attention well + if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name: + return False + + # Check for any FA backend: FA3 (fa3_fwd_interface, flash_attn_interface) or FA2 (flash_attn) + # Try FA3 from fa3-fwd PyPI package + try: + import fa3_fwd_interface # noqa: F401 + + return True + except (ImportError, ModuleNotFoundError): + pass + + # Try FA3 from flash-attention source build + try: + import flash_attn_interface # noqa: F401 + + return True + except (ImportError, ModuleNotFoundError): + pass + + # Try FA2 from flash-attn package + from flash_attn import __version__ + + if __version__ < "2.6.0": + raise ImportError("install flash_attn >= 2.6.0") + return True + except (ImportError, ModuleNotFoundError): + if not packages_info.get("has_aiter", False): + logger.warning("No Flash Attention backend found, using pytorch SDPA implementation") + return False + + def get_packages_info(self) -> dict: + """Get the packages info dictionary.""" + return self.packages_info + + +PACKAGES_CHECKER = PackagesEnvChecker() + + +def __getattr__(name): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm_omni/diffusion/executor/__init__.py b/vllm_omni/diffusion/executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..e41f41d119e8d0daf548404d5692ca634c078d18 --- /dev/null +++ b/vllm_omni/diffusion/executor/abstract.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from typing import Any + +from vllm.utils.import_utils import resolve_obj_by_qualname + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest + + +class DiffusionExecutor(ABC): + """Abstract base class for Diffusion executors.""" + + uses_multiproc: bool = False + + @staticmethod + def get_class(od_config: OmniDiffusionConfig) -> type["DiffusionExecutor"]: + executor_class: type[DiffusionExecutor] + distributed_executor_backend = od_config.distributed_executor_backend + + if 
isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, DiffusionExecutor): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"DiffusionExecutor. Got {distributed_executor_backend}." + ) + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": + raise NotImplementedError("ray backend is not yet supported.") + elif distributed_executor_backend == "mp": + from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor + + executor_class = MultiprocDiffusionExecutor + elif distributed_executor_backend == "external_launcher": + raise NotImplementedError("external_launcher backend is not yet supported.") + elif isinstance(distributed_executor_backend, str): + try: + executor_class = resolve_obj_by_qualname(distributed_executor_backend) + except (ImportError, ValueError) as e: + raise ValueError( + f"Failed to load executor backend '{distributed_executor_backend}'. " + f"Ensure it is a valid python path. Error: {e}" + ) from e + + if not issubclass(executor_class, DiffusionExecutor): + raise TypeError( + f"distributed_executor_backend must be a subclass of DiffusionExecutor. Got {executor_class}." + ) + else: + raise ValueError(f"Unknown distributed executor backend: {distributed_executor_backend}") + return executor_class + + def __init__(self, od_config: OmniDiffusionConfig): + self.od_config = od_config + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + """Initialize the executor (e.g., launch workers, setup IPC).""" + pass + + @abstractmethod + def add_req(self, requests: OmniDiffusionRequest) -> DiffusionOutput: + """Add requests to the execution queue.""" + pass + + @abstractmethod + def collective_rpc( + self, + method: str, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + unique_reply_rank: int | None = None, + ) -> Any: + """Execute a method on workers.""" + pass + + @abstractmethod + def check_health(self) -> None: + """Check if the executor and workers are healthy.""" + pass + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the executor and release resources.""" + pass diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..a421a01eddd71fc26b1030eb6b86743eca81ae27 --- /dev/null +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -0,0 +1,197 @@ +import multiprocessing as mp +import time +import weakref +from dataclasses import dataclass +from typing import Any + +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput +from vllm_omni.diffusion.executor.abstract import DiffusionExecutor +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.scheduler import Scheduler +from vllm_omni.diffusion.worker import WorkerProc + +logger = init_logger(__name__) + + +@dataclass +class BackgroundResources: + """ + Used as a finalizer for clean shutdown. 
+ """ + + scheduler: Scheduler | None = None + processes: list[mp.Process] | None = None + + def __call__(self): + """Clean up background resources.""" + if self.scheduler is not None: + try: + for _ in range(self.scheduler.num_workers): + self.scheduler.mq.enqueue(SHUTDOWN_MESSAGE) + self.scheduler.close() + except Exception as exc: + logger.warning("Failed to send shutdown signal: %s", exc) + if self.processes: + for proc in self.processes: + if not proc.is_alive(): + continue + proc.join(30) + if proc.is_alive(): + logger.warning("Terminating diffusion worker %s after timeout", proc.name) + proc.terminate() + proc.join(30) + + +class MultiprocDiffusionExecutor(DiffusionExecutor): + uses_multiproc: bool = True + + def _init_executor(self) -> None: + self._processes: list[mp.Process] = [] + self._closed = False + + # Initialize scheduler + self.scheduler = Scheduler() + self.scheduler.initialize(self.od_config) + broadcast_handle = self.scheduler.get_broadcast_handle() + + # Launch workers + processes, result_handle = self._launch_workers(broadcast_handle) + + if result_handle is not None: + self.scheduler.initialize_result_queue(result_handle) + else: + logger.error("Failed to get result queue handle from workers") + + self._processes = processes + + self.resources = BackgroundResources(scheduler=self.scheduler, processes=self._processes) + self._finalizer = weakref.finalize(self, self.resources) + + def _launch_workers(self, broadcast_handle): + od_config = self.od_config + logger.info("Starting server...") + + num_gpus = od_config.num_gpus + mp.set_start_method("spawn", force=True) + processes = [] + + # Launch all worker processes + scheduler_pipe_readers = [] + scheduler_pipe_writers = [] + + for i in range(num_gpus): + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_writers.append(writer) + process = mp.Process( + target=WorkerProc.worker_main, + args=( + i, # rank + od_config, + writer, + broadcast_handle, + ), + name=f"DiffusionWorker-{i}", + daemon=True, + ) + scheduler_pipe_readers.append(reader) + process.start() + processes.append(process) + + # Wait for all workers to be ready + scheduler_infos = [] + result_handle = None + for writer in scheduler_pipe_writers: + writer.close() + + for i, reader in enumerate(scheduler_pipe_readers): + try: + data = reader.recv() + except EOFError: + logger.error(f"Rank {i} scheduler is dead. Please check if there are relevant logs.") + processes[i].join() + logger.error(f"Exit code: {processes[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError("Initialization failed. 
Please see the error messages above.") + + if i == 0: + result_handle = data.get("result_handle") + + scheduler_infos.append(data) + reader.close() + + logger.debug("All workers are ready") + + return processes, result_handle + + def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: + return self.scheduler.add_req(request) + + def collective_rpc( + self, + method: str, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + unique_reply_rank: int | None = None, + ) -> Any: + if self._closed: + raise RuntimeError("DiffusionExecutor is closed.") + + deadline = None if timeout is None else time.monotonic() + timeout + kwargs = kwargs or {} + + # Prepare RPC request message + rpc_request = { + "type": "rpc", + "method": method, + "args": args, + "kwargs": kwargs, + "output_rank": unique_reply_rank, + } + + try: + # Broadcast RPC request to all workers via unified message queue + self.scheduler.mq.enqueue(rpc_request) + + # Determine which workers we expect responses from + num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus + + responses = [] + for _ in range(num_responses): + dequeue_timeout = None if deadline is None else (deadline - time.monotonic()) + try: + if self.scheduler.result_mq is None: + raise RuntimeError("Result queue not initialized") + + response = self.scheduler.result_mq.dequeue(timeout=dequeue_timeout) + + # Check if response indicates an error + if isinstance(response, dict) and response.get("status") == "error": + raise RuntimeError( + f"Worker failed with error '{response.get('error')}', " + "please check the stack trace above for the root cause" + ) + + responses.append(response) + except TimeoutError as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e + + return responses[0] if unique_reply_rank is not None else responses + + except Exception as e: + logger.error(f"RPC call failed: {e}") + raise + + def check_health(self) -> None: + # Simple check if processes are alive + for p in self._processes: + if not p.is_alive(): + raise RuntimeError(f"Worker process {p.name} is dead") + + def shutdown(self) -> None: + self._closed = True + self._finalizer() diff --git a/vllm_omni/diffusion/forward_context.py b/vllm_omni/diffusion/forward_context.py new file mode 100644 index 0000000000000000000000000000000000000000..02082e7b1a4f82fa4b1a7e731b8d7c95641afbd4 --- /dev/null +++ b/vllm_omni/diffusion/forward_context.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass + +from vllm.config import VllmConfig + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionMetadata, +) +from vllm_omni.diffusion.data import OmniDiffusionConfig + + +@dataclass +class ForwardContext: + """ + set forward context for diffusion models + """ + + vllm_config: VllmConfig | None = None + omni_diffusion_config: OmniDiffusionConfig | None = None + attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None + split_text_embed_in_sp: bool = False + # whether to split the text embed in sequence parallel, if True, the text embed will be split in sequence parallel + + # Sequence Parallel padding support + # When sequence length is not divisible by SP world size, padding is added + # These values are used by SequenceParallelGatherHook to remove padding, + # and by attention layers to create attention masks dynamically + sp_padding_size: int = 0 + # Original sequence length before padding (for 
removing padding in gather) + sp_original_seq_len: int | None = None + + def __post_init__(self): + pass + + +_forward_context: ForwardContext | None = None + + +def get_forward_context() -> ForwardContext: + """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. Please use `set_forward_context` to set the forward context." + ) + return _forward_context + + +def is_forward_context_available() -> bool: + return _forward_context is not None + + +def create_forward_context( + vllm_config: VllmConfig | None = None, + omni_diffusion_config: OmniDiffusionConfig | None = None, + attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None, + split_text_embed_in_sp: bool = False, +): + return ForwardContext( + vllm_config=vllm_config, + omni_diffusion_config=omni_diffusion_config, + attn_metadata=attn_metadata, + split_text_embed_in_sp=split_text_embed_in_sp, + ) + + +@contextmanager +def override_forward_context(forward_context: ForwardContext | None): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. + """ + global _forward_context + prev_context = _forward_context + _forward_context = forward_context + try: + yield + finally: + _forward_context = prev_context + + +@contextmanager +def set_forward_context( + vllm_config: VllmConfig | None = None, + omni_diffusion_config: OmniDiffusionConfig | None = None, + attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None, + split_text_embed_in_sp: bool = False, +): + """A context manager that stores the current forward context, + can be attention metadata, split_text_embed_in_sp, etc. + Here we can inject common logic for every model forward pass. + """ + forward_context = create_forward_context( + vllm_config=vllm_config, + omni_diffusion_config=omni_diffusion_config, + attn_metadata=attn_metadata, + split_text_embed_in_sp=split_text_embed_in_sp, + ) + # vLLM CustomOp dispatch (e.g. QKVParallelLinear) requires a global + # vLLM config set via set_current_vllm_config(). + with override_forward_context(forward_context): + if vllm_config is None: + yield + else: + # Local import to avoid importing vllm.config.vllm at module import time. 
+ from vllm.config.vllm import set_current_vllm_config + + with set_current_vllm_config(vllm_config): + yield diff --git a/vllm_omni/diffusion/hooks/__init__.py b/vllm_omni/diffusion/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68f918383294159e48ccecd091471875eb76e878 --- /dev/null +++ b/vllm_omni/diffusion/hooks/__init__.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Hook mechanism for model forward interception.""" + +from vllm_omni.diffusion.hooks.base import ( + BaseState, + HookRegistry, + ModelHook, + StateManager, +) +from vllm_omni.diffusion.hooks.sequence_parallel import ( + SequenceParallelGatherHook, + SequenceParallelSplitHook, + apply_sequence_parallel, + disable_sequence_parallel_for_model, + enable_sequence_parallel_for_model, + remove_sequence_parallel, +) + +__all__ = [ + # Base hooks + "BaseState", + "StateManager", + "ModelHook", + "HookRegistry", + # Sequence parallel hooks (corresponds to diffusers' context_parallel) + "SequenceParallelSplitHook", + "SequenceParallelGatherHook", + "apply_sequence_parallel", + "remove_sequence_parallel", + "enable_sequence_parallel_for_model", + "disable_sequence_parallel_for_model", +] diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8b330a19f422ed4bae4bd74a22554283072648dc --- /dev/null +++ b/vllm_omni/diffusion/hooks/base.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base hook classes for model forward interception. + +This module provides the foundational hook mechanism that allows intercepting +and modifying model forward passes without invasive changes to model code. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch.nn as nn + + +class BaseState: + """Base class for hook state containers.""" + + def reset(self) -> None: # pragma: no cover - default is no-op + pass + + +class StateManager: + """Manage per-context hook state instances.""" + + def __init__(self, state_cls: Callable[[], BaseState]): + self._state_cls = state_cls + self._states: dict[str, BaseState] = {} + self._context: str = "default" + + def set_context(self, name: str) -> None: + self._context = name or "default" + + def get_state(self) -> BaseState: + if self._context not in self._states: + self._states[self._context] = self._state_cls() + return self._states[self._context] + + def reset(self) -> None: + self._states.clear() + + +class ModelHook: + """Base class for model hooks that can override a module's forward. + + Hooks can intercept the forward pass at two points: + - pre_forward: Called before the original forward, can modify args/kwargs + - post_forward: Called after the original forward, can modify output + + Subclasses can override either or both methods. The default implementations + pass through args/kwargs/output unchanged. + + For more complex behavior, override new_forward to completely replace + the forward logic. + """ + + def initialize_hook(self, module: nn.Module) -> nn.Module: + """Initialize the hook when it's registered to a module. + + Args: + module: The module this hook is being attached to. + + Returns: + The module (possibly modified). 
+ """ + return module + + def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tuple, dict]: + """Called before the module's forward pass. + + Args: + module: The module being called. + *args: Positional arguments to forward. + **kwargs: Keyword arguments to forward. + + Returns: + Tuple of (args, kwargs) to pass to the forward method. + """ + return args, kwargs + + def post_forward(self, module: nn.Module, output: Any) -> Any: + """Called after the module's forward pass. + + Args: + module: The module that was called. + output: The output from the forward method. + + Returns: + The (possibly modified) output. + """ + return output + + def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any: + """Override the module's forward pass completely. + + The default implementation calls pre_forward, then the original forward, + then post_forward. Override this method for more complex behavior. + + Args: + module: The module being called. + *args: Positional arguments to forward. + **kwargs: Keyword arguments to forward. + + Returns: + The output of the forward pass. + """ + args, kwargs = self.pre_forward(module, *args, **kwargs) + output = module._original_forward(*args, **kwargs) # type: ignore[attr-defined] + return self.post_forward(module, output) + + def reset_state(self, module: nn.Module) -> nn.Module: + """Reset any state associated with this hook. + + Args: + module: The module this hook is attached to. + + Returns: + The module. + """ + return module + + +@dataclass +class _WrappedForward: + """Wrapper that intercepts forward calls and dispatches to hooks.""" + + module: nn.Module + + def __call__(self, *args: Any, **kwargs: Any): + registry: HookRegistry | None = getattr(self.module, "_hook_registry", None) + if registry is None or not registry._hooks: + return self.module._original_forward(*args, **kwargs) + return registry.dispatch(*args, **kwargs) + + +class HookRegistry: + """Registry of hooks attached to a module. + + Manages multiple hooks that can intercept a module's forward pass. + Hooks are called in sorted order by name for determinism. + """ + + def __init__(self, module: nn.Module): + self.module = module + self._hooks: dict[str, ModelHook] = {} + + @classmethod + def get_or_create(cls, module: nn.Module) -> HookRegistry: + """Get existing registry or create a new one for the module. + + Args: + module: The module to get/create a registry for. + + Returns: + The HookRegistry for this module. + """ + registry: HookRegistry | None = getattr(module, "_hook_registry", None) + if registry is None: + registry = cls(module) + setattr(module, "_hook_registry", registry) + + # Wrap module.forward once so hooks can intercept calls. + if not hasattr(module, "_original_forward"): + module._original_forward = module.forward # type: ignore[attr-defined] + module.forward = _WrappedForward(module) # type: ignore[assignment] + + return registry + + def register_hook(self, name: str, hook: ModelHook) -> None: + """Register a hook with the given name. + + Args: + name: Unique name for this hook. + hook: The hook instance to register. + """ + hook.initialize_hook(self.module) + self._hooks[name] = hook + + def remove_hook(self, name: str) -> None: + """Remove a hook by name. + + Args: + name: The name of the hook to remove. + """ + if name in self._hooks: + del self._hooks[name] + + def get_hook(self, name: str) -> ModelHook | None: + """Get a hook by name. + + Args: + name: The name of the hook. + + Returns: + The hook if found, None otherwise. 
+ """ + return self._hooks.get(name) + + def dispatch(self, *args: Any, **kwargs: Any) -> Any: + """Dispatch a forward call through registered hooks. + + Currently supports a single active hook. Multiple hooks are called + in sorted order by name, with each hook's output passed to the next. + + Args: + *args: Positional arguments to forward. + **kwargs: Keyword arguments to forward. + + Returns: + The output of the forward pass. + """ + if not self._hooks: + return self.module._original_forward(*args, **kwargs) # type: ignore[attr-defined] + + # For single hook case, call directly + if len(self._hooks) == 1: + hook = next(iter(self._hooks.values())) + return hook.new_forward(self.module, *args, **kwargs) + + # For multiple hooks, chain them in sorted order + # Each hook can modify args/kwargs via pre_forward + sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0]) + + # Apply all pre_forward hooks + for _, hook in sorted_hooks: + args, kwargs = hook.pre_forward(self.module, *args, **kwargs) + + # Call original forward + output = self.module._original_forward(*args, **kwargs) # type: ignore[attr-defined] + + # Apply all post_forward hooks in reverse order + for _, hook in reversed(sorted_hooks): + output = hook.post_forward(self.module, output) + + return output + + def reset_hook(self, name: str) -> None: + """Reset a hook's state by name. + + Args: + name: The name of the hook to reset. + """ + hook = self._hooks.get(name) + if hook is not None: + hook.reset_state(self.module) diff --git a/vllm_omni/diffusion/hooks/sequence_parallel.py b/vllm_omni/diffusion/hooks/sequence_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2c2b7550607d90d5df64a6e0b0751ac90a75a1 --- /dev/null +++ b/vllm_omni/diffusion/hooks/sequence_parallel.py @@ -0,0 +1,700 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and the HuggingFace Team. +# All rights reserved. +# +# This module is adapted from HuggingFace diffusers library: +# diffusers/src/diffusers/hooks/context_parallel.py +# +# NOTE: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. +# We use the term "Sequence Parallelism" to align with vLLM-Omni's existing terminology. +# +# Key adaptations from diffusers: +# - ModuleForwardMetadata: parameter lookup logic (adapted) +# - SequenceParallelSplitHook/GatherHook: hook structure (adapted from ContextParallel*) +# - apply_sequence_parallel: registration logic (adapted from apply_context_parallel) +# +# Key differences from diffusers: +# - Uses vLLM-Omni's SequenceParallelGroupCoordinator instead of DeviceMesh +# - Uses sp_shard/sp_gather from sp_sharding.py instead of funcol operations +# - Supports Ulysses + Ring hybrid parallelism +# +"""Sequence Parallelism hooks for non-intrusive SP support. + +This module implements the hook-based mechanism for applying sequence parallelism +to models without modifying their forward() methods. + +Usage: + 1. Define _sp_plan on your model class (corresponds to diffusers' _cp_plan) + 2. Call apply_sequence_parallel(model, config, plan) to enable SP + 3. Call remove_sequence_parallel(model, plan) to disable SP + +The hooks automatically shard inputs before forward and gather outputs after, +based on the plan specification. 
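+
+Example (a minimal sketch; the module name "proj_out" and the dims are
+illustrative, mirroring the `apply_sequence_parallel` docstring below):
+
+    plan = {
+        "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
+        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+    }
+    apply_sequence_parallel(model, SequenceParallelConfig(ulysses_degree=2), plan)
+    ...  # run forward passes with SP enabled
+    remove_sequence_parallel(model, plan)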
+""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from vllm_omni.diffusion.distributed.sp_plan import ( + AnySequenceParallelInput, + SequenceParallelConfig, + SequenceParallelInput, + SequenceParallelModelPlan, + SequenceParallelOutput, + SequenceParallelPartialInput, +) +from vllm_omni.diffusion.distributed.sp_sharding import sp_gather, sp_shard +from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook + +logger = init_logger(__name__) + +# Hook name templates for identifying SP hooks +_SP_INPUT_HOOK_TEMPLATE = "sp_input---{}" +_SP_OUTPUT_HOOK_TEMPLATE = "sp_output---{}" + + +@dataclass +class ModuleForwardMetadata: + """Metadata for mapping forward() parameter names to args/kwargs positions. + + This caches the inspection of a module's forward signature to efficiently + locate parameters by name in subsequent calls. + """ + + cached_parameter_indices: dict[str, int] | None = None + _cls: type | None = None + + def _get_parameter_from_args_kwargs( + self, + identifier: str, + args: tuple = (), + kwargs: dict | None = None, + ) -> tuple[Any, bool, int | None]: + """Get a parameter value from args or kwargs by name. + + Args: + identifier: The parameter name to look up. + args: Positional arguments passed to forward. + kwargs: Keyword arguments passed to forward. + + Returns: + Tuple of (value, is_kwarg, index). + - value: The parameter value (or None if not found) + - is_kwarg: True if found in kwargs + - index: Position in args if found there + + Raises: + ValueError: If parameter not found in signature. + """ + kwargs = kwargs or {} + + # First check kwargs + if identifier in kwargs: + return kwargs[identifier], True, None + + # Check cached indices + if self.cached_parameter_indices is not None: + index = self.cached_parameter_indices.get(identifier, None) + if index is None: + raise ValueError(f"Parameter '{identifier}' not found in cached indices.") + if index < len(args): + return args[index], False, index + return None, False, index + + # Build cache from forward signature + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # Skip `self` + self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + + if identifier not in self.cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature.") + + index = self.cached_parameter_indices[identifier] + + if index >= len(args): + return None, False, index + + return args[index], False, index + + +def _unwrap_module(module: nn.Module) -> nn.Module: + """Unwrap a module from any wrappers to get the original class. + + Args: + module: Potentially wrapped module. + + Returns: + The unwrapped module. + """ + # Handle common wrappers + while hasattr(module, "_modules") and len(module._modules) == 1: + inner = next(iter(module._modules.values())) + if inner is not None: + module = inner + else: + break + return module + + +class SequenceParallelSplitHook(ModelHook): + """Hook for splitting inputs before a module's forward pass. + + This hook is registered to modules that need their inputs sharded + across sequence parallel ranks. It intercepts the forward call, + shards specified inputs according to the plan, and passes the + sharded inputs to the original forward. 
+ + For split_output=True inputs, it shards the output instead. + + Supports both SequenceParallelInput (full split) and SequenceParallelPartialInput + (partial split for text/image separation). + + Note: This corresponds to `ContextParallelSplitHook` in diffusers. + """ + + def __init__( + self, + metadata: dict[str | int, AnySequenceParallelInput | list[AnySequenceParallelInput]], + config: SequenceParallelConfig, + ) -> None: + super().__init__() + self.metadata = metadata + self.config = config + self.module_forward_metadata: ModuleForwardMetadata | None = None + # Cache for text lengths resolved from kwargs + self._text_len_cache: dict[str, int] = {} + + def initialize_hook(self, module: nn.Module) -> nn.Module: + cls = _unwrap_module(module).__class__ + self.module_forward_metadata = ModuleForwardMetadata(_cls=cls) + return module + + def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tuple, dict]: + """Shard inputs before forward.""" + args_list = list(args) + # Clear text length cache for this forward pass + self._text_len_cache.clear() + + for name, spm in self.metadata.items(): + # Skip if this is a split_output entry (handled in post_forward) + if isinstance(spm, (SequenceParallelInput, SequenceParallelPartialInput)) and spm.split_output: + continue + + # Get the parameter value + input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs( + name, args_list, kwargs + ) + + if input_val is None: + continue + + # Shard the input + if isinstance(input_val, torch.Tensor): + input_val = self._prepare_sp_input(input_val, spm, args_list, kwargs) + elif isinstance(input_val, (list, tuple)): + # Handle list/tuple of tensors with per-element config + if not isinstance(spm, (list, tuple)): + raise ValueError( + f"Expected list/tuple of SequenceParallelInput for parameter '{name}' " + f"which is a list/tuple, but got {type(spm).__name__}" + ) + if len(input_val) != len(spm): + raise ValueError(f"Expected {len(spm)} elements for parameter '{name}', got {len(input_val)}") + sharded_input_val = [] + for i, x in enumerate(input_val): + if torch.is_tensor(x) and not spm[i].split_output: + x = self._prepare_sp_input(x, spm[i], args_list, kwargs) + sharded_input_val.append(x) + input_val = type(input_val)(sharded_input_val) + else: + raise ValueError(f"Unsupported input type for sharding: {type(input_val).__name__}") + + # Update args or kwargs + if is_kwarg: + kwargs[name] = input_val + elif index is not None and index < len(args_list): + args_list[index] = input_val + else: + raise ValueError(f"Failed to update parameter '{name}' after sharding.") + + # Store kwargs for post_forward to resolve text lengths + self._last_kwargs = kwargs + self._last_args = tuple(args_list) + + return tuple(args_list), kwargs + + def post_forward(self, module: nn.Module, output: Any) -> Any: + """Shard outputs for split_output=True entries.""" + is_tensor = isinstance(output, torch.Tensor) + is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output) + + if not is_tensor and not is_tensor_list: + # No tensor outputs to shard + return output + + output_list = [output] if is_tensor else list(output) + + for index, spm in self.metadata.items(): + if not isinstance(index, int): + continue + if not isinstance(spm, (SequenceParallelInput, SequenceParallelPartialInput)) or not spm.split_output: + continue + if index >= len(output_list): + raise ValueError(f"Index {index} out of bounds for output of length 
{len(output_list)}.") + + output_list[index] = self._prepare_sp_input(output_list[index], spm, self._last_args, self._last_kwargs) + + return output_list[0] if is_tensor else type(output)(output_list) + + def _resolve_text_len( + self, + sp_input: SequenceParallelPartialInput, + args: tuple, + kwargs: dict, + ) -> int: + """Resolve text length from the source specification.""" + source = sp_input.text_len_source + + if isinstance(source, int): + return source + + # String source - look up from kwargs or args + if source in self._text_len_cache: + return self._text_len_cache[source] + + # Try to get from kwargs/args + try: + val, _, _ = self.module_forward_metadata._get_parameter_from_args_kwargs(source, args, kwargs) + if val is None: + raise ValueError(f"Parameter '{source}' is None, cannot determine text length.") + if isinstance(val, torch.Tensor): + # TODO: Currently assumes batch_size=1, where shape[0] is sequence length. + # For batch inference support, this should be updated to handle + # shape (batch_size, seq_len, ...) where text_len varies per sample. + text_len = val.shape[0] + elif isinstance(val, int): + text_len = val + else: + raise ValueError(f"Cannot determine text length from '{source}' of type {type(val).__name__}") + self._text_len_cache[source] = text_len + return text_len + except ValueError as e: + raise ValueError(f"Failed to resolve text_len_source '{source}': {e}") from e + + def _prepare_sp_input( + self, + x: torch.Tensor, + sp_input: AnySequenceParallelInput, + args: tuple = (), + kwargs: dict | None = None, + ) -> torch.Tensor: + """Shard a tensor according to the input specification.""" + kwargs = kwargs or {} + + if sp_input.expected_dims is not None and x.dim() != sp_input.expected_dims: + logger.warning_once(f"Expected tensor with {sp_input.expected_dims} dims, got {x.dim()}. Skipping split.") + return x + + if isinstance(sp_input, SequenceParallelInput): + # Full split with optional auto-padding + if sp_input.auto_pad: + return self._shard_with_auto_pad(x, sp_input.split_dim) + return sp_shard(x, sp_input.split_dim, validate=False) + elif isinstance(sp_input, SequenceParallelPartialInput): + # Partial split: keep text portion, split image portion + text_len = self._resolve_text_len(sp_input, args, kwargs) + dim = sp_input.split_dim + + # Split tensor into text and image portions + text_part = x.narrow(dim, 0, text_len) + image_part = x.narrow(dim, text_len, x.size(dim) - text_len) + + # Only shard the image portion + image_part_sharded = sp_shard(image_part, dim, validate=False) + + # Concatenate back: [text_full, image_sharded] + return torch.cat([text_part, image_part_sharded], dim=dim) + else: + raise ValueError(f"Unsupported input config type: {type(sp_input).__name__}") + + def _shard_with_auto_pad(self, x: torch.Tensor, dim: int) -> torch.Tensor: + """Shard tensor with automatic padding and attention mask creation. + + When sequence length is not divisible by SP world size, this method: + 1. Pads the tensor to make it divisible + 2. Creates an attention mask indicating valid vs padding positions + 3. 
Stores the mask and padding info in ForwardContext + """ + from vllm_omni.diffusion.attention.selector import get_attn_backend + from vllm_omni.diffusion.distributed.parallel_state import ( + get_ring_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + ) + from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available + + world_size = get_sequence_parallel_world_size() + if world_size == 1: + return x + + seq_len = x.size(dim) + remainder = seq_len % world_size + + if remainder == 0: + # No padding needed + return sp_shard(x, dim, validate=False) + + # Check backend compatibility + attn_backend = get_attn_backend(-1) + if not attn_backend.supports_attention_mask: + raise ValueError( + f"Sequence length ({seq_len}) is not divisible by SP world size ({world_size}). " + f"Cannot use {attn_backend.get_name()} which does not support attention_mask. " + f"Please switch to SDPA or Ascend attention backend." + ) + + # Ring attention does not support attention_mask + if get_ring_parallel_world_size() > 1: + raise ValueError( + f"Sequence length ({seq_len}) is not divisible by SP world size ({world_size}). " + f"Cannot use Ring attention which does not support attention_mask. " + f"Please switch to Ulysses SP only." + ) + + # Calculate padding + pad_size = world_size - remainder + padded_seq_len = seq_len + pad_size + + # Pad the tensor + pad_shape = list(x.shape) + pad_shape[dim] = pad_size + padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device) + x_padded = torch.cat([x, padding], dim=dim) + + # Store padding info in forward context (only once, for primary tensor) + # Attention layers will create masks dynamically using this info + if is_forward_context_available(): + ctx = get_forward_context() + # Only set if not already set (first auto_pad tensor wins) + if ctx.sp_original_seq_len is None: + ctx.sp_padding_size = pad_size + ctx.sp_original_seq_len = seq_len + logger.debug( + f"Auto-padded sequence from {seq_len} to {padded_seq_len} " + f"(pad_size={pad_size}, world_size={world_size}, dim={dim})" + ) + + # Shard the padded tensor + rank = get_sequence_parallel_rank() + return x_padded.chunk(world_size, dim=dim)[rank] + + +class SequenceParallelGatherHook(ModelHook): + """Hook for gathering outputs after a module's forward pass. + + This hook is registered to modules that need their outputs gathered + from all sequence parallel ranks. It intercepts the output and gathers + it according to the plan specification. + + Note: This corresponds to `ContextParallelGatherHook` in diffusers. 
+ """ + + def __init__( + self, + metadata: SequenceParallelOutput | list[SequenceParallelOutput], + config: SequenceParallelConfig, + ) -> None: + super().__init__() + if isinstance(metadata, SequenceParallelOutput): + metadata = [metadata] + self.metadata = metadata + self.config = config + + def initialize_hook(self, module: nn.Module) -> nn.Module: + return module + + def post_forward(self, module: nn.Module, output: Any) -> Any: + """Gather outputs after forward and remove padding if applied.""" + from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available + + is_tensor = isinstance(output, torch.Tensor) + + if is_tensor: + output = [output] + elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)): + raise ValueError(f"Expected tensor or list/tuple of tensors, got {type(output).__name__}") + + output = list(output) + + if len(output) != len(self.metadata): + raise ValueError(f"Expected {len(self.metadata)} outputs, got {len(output)}.") + + # Check if padding was applied during split + original_seq_len = None + if is_forward_context_available(): + ctx = get_forward_context() + original_seq_len = ctx.sp_original_seq_len + + for i, spm in enumerate(self.metadata): + if spm is None: + continue + + x = output[i] + if spm.expected_dims is not None and x.dim() != spm.expected_dims: + logger.warning_once( + f"Expected output tensor with {spm.expected_dims} dims, got {x.dim()}. Skipping gather." + ) + continue + + # Gather from all ranks + gathered = sp_gather(x, spm.gather_dim, validate=False) + + # Remove padding if it was applied + if original_seq_len is not None and gathered.size(spm.gather_dim) > original_seq_len: + gathered = gathered.narrow(spm.gather_dim, 0, original_seq_len) + logger.debug(f"Removed padding: gathered shape {gathered.shape} (original_seq_len={original_seq_len})") + + output[i] = gathered + + return output[0] if is_tensor else type(output)(output) + + +def _get_submodule_by_name(model: nn.Module, name: str) -> nn.Module | list[nn.Module]: + """Get a submodule by dotted name, supporting wildcards. + + Args: + model: The root module. + name: Dotted path to submodule. Use "*" to match all children + of a ModuleList. + + Returns: + The submodule or list of submodules if wildcard used. + + Raises: + ValueError: If the path is invalid or module not found. + """ + if name.count("*") > 1: + raise ValueError("Wildcard '*' can only be used once in the name") + return _find_submodule_by_name(model, name) + + +def _find_submodule_by_name(model: nn.Module, name: str) -> nn.Module | list[nn.Module]: + """Recursive helper for _get_submodule_by_name.""" + if name == "": + return model + + first_atom, remaining_name = name.split(".", 1) if "." 
in name else (name, "") + + if first_atom == "*": + if not isinstance(model, nn.ModuleList): + raise ValueError("Wildcard '*' can only be used with ModuleList") + submodules = [] + for submodule in model: + subsubmodules = _find_submodule_by_name(submodule, remaining_name) + if not isinstance(subsubmodules, list): + subsubmodules = [subsubmodules] + submodules.extend(subsubmodules) + return submodules + else: + if hasattr(model, first_atom): + submodule = getattr(model, first_atom) + return _find_submodule_by_name(submodule, remaining_name) + else: + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") + + +def apply_sequence_parallel( + module: nn.Module, + config: SequenceParallelConfig, + plan: SequenceParallelModelPlan, +) -> None: + """Apply sequence parallel hooks to a model according to the plan. + + This function registers hooks on the specified submodules to automatically + shard inputs and gather outputs for sequence parallelism. + + Note: This corresponds to `apply_context_parallel` in diffusers. + + The complete SP flow is: + 1. Input sharding (SequenceParallelSplitHook): Split sequence across SP ranks + 2. Attention parallelism (handled by vLLM-Omni's Attention layer): + - Ulysses: All-to-All over Q/K/V heads + - Ring: K/V circulation in ring topology + - Hybrid: Both (Ulysses handles head redistribution, Ring handles K/V) + 3. Output gathering (SequenceParallelGatherHook): Gather sequence from SP ranks + + Args: + module: The model to apply SP to. + config: The sequence parallel configuration. + plan: Dictionary mapping module names to input/output specifications. + + Example: + config = SequenceParallelConfig(ulysses_degree=2) + plan = { + "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)}, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + apply_sequence_parallel(model, config, plan) + + Note: + vLLM-Omni's Attention layer automatically handles the internal + parallelism (Ulysses All-to-All or Ring attention) based on the + forward_context configuration. This function only handles input/output + sharding for the model as a whole. 
+ """ + logger.debug( + f"Applying sequence parallel with config: ulysses={config.ulysses_degree}, " + f"ring={config.ring_degree}, plan keys: {list(plan.keys())}" + ) + + for module_id, sp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + logger.debug(f"Applying SP hooks to '{module_id}' ({len(submodule)} module(s))") + + for m in submodule: + if isinstance(sp_model_plan, dict): + # Input specification + hook = SequenceParallelSplitHook(sp_model_plan, config) + hook_name = _SP_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(sp_model_plan, (SequenceParallelOutput, list, tuple)): + # Output specification + if isinstance(sp_model_plan, SequenceParallelOutput): + sp_model_plan = [sp_model_plan] + if not all(isinstance(x, SequenceParallelOutput) or x is None for x in sp_model_plan): + raise ValueError(f"Expected SequenceParallelOutput elements, got {sp_model_plan}") + hook = SequenceParallelGatherHook(sp_model_plan, config) + hook_name = _SP_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported plan type: {type(sp_model_plan).__name__}") + + registry = HookRegistry.get_or_create(m) + registry.register_hook(hook_name, hook) + + +def remove_sequence_parallel( + module: nn.Module, + plan: SequenceParallelModelPlan, +) -> None: + """Remove sequence parallel hooks from a model. + + Note: This corresponds to `remove_context_parallel` in diffusers. + + Args: + module: The model to remove SP from. + plan: The same plan used when applying SP. + """ + for module_id, sp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + for m in submodule: + registry = getattr(m, "_hook_registry", None) + if registry is None: + continue + + if isinstance(sp_model_plan, dict): + hook_name = _SP_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(sp_model_plan, (SequenceParallelOutput, list, tuple)): + hook_name = _SP_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + continue + + registry.remove_hook(hook_name) + + +def enable_sequence_parallel_for_model( + model: nn.Module, + config: SequenceParallelConfig | None = None, +) -> None: + """Enable sequence parallelism for a model using its _sp_plan. + + This is a convenience function that reads the model's _sp_plan attribute + and applies sequence parallelism automatically. + + Note: This corresponds to `enable_context_parallel_for_model` in diffusers, + but uses vLLM-Omni's _sp_plan instead of diffusers' _cp_plan. + + The function performs two main tasks: + 1. Applies _sp_plan hooks to shard inputs and gather outputs + 2. Ensures Attention layers are configured for the correct parallel mode + (handled automatically by vLLM-Omni's forward_context mechanism) + + Args: + model: The model to enable SP for. Must have a _sp_plan attribute. + config: Optional config. If None, uses default based on current + parallel state. + + Raises: + ValueError: If model has no _sp_plan defined. 
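+
+    Example (a minimal sketch; `transformer` is any module that defines `_sp_plan`):
+        enable_sequence_parallel_for_model(transformer)
+        ...  # forward passes now run with SP hooks applied
+        disable_sequence_parallel_for_model(transformer)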
+ + Note: + vLLM-Omni supports Ulysses + Ring hybrid parallelism: + - ulysses_degree > 1: Uses All-to-All communication over Q/K/V heads + - ring_degree > 1: Uses Ring attention with K/V passing + - Both > 1: Hybrid mode (Ulysses handles head redistribution, + Ring handles K/V circulation) + """ + from vllm_omni.diffusion.distributed.parallel_state import ( + get_ring_parallel_world_size, + get_ulysses_parallel_world_size, + ) + from vllm_omni.diffusion.distributed.sp_plan import get_sp_plan_from_model + + plan = get_sp_plan_from_model(model) + if plan is None: + raise ValueError( + f"Model {model.__class__.__name__} has no _sp_plan defined. " + f"Define _sp_plan as a class attribute or pass a plan explicitly." + ) + + if config is None: + # Create config from current parallel state + ulysses_degree = get_ulysses_parallel_world_size() + ring_degree = get_ring_parallel_world_size() + config = SequenceParallelConfig( + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + ) + if ulysses_degree > 1 and ring_degree > 1: + mode = "hybrid" + elif ulysses_degree > 1: + mode = "ulysses" + else: + mode = "ring" + logger.info( + f"Created SP config from parallel state: " + f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, " + f"mode={mode}" + ) + + apply_sequence_parallel(model, config, plan) + logger.info(f"Enabled sequence parallelism for {model.__class__.__name__}") + + +def disable_sequence_parallel_for_model(model: nn.Module) -> None: + """Disable sequence parallelism for a model. + + Note: This corresponds to `disable_context_parallel_for_model` in diffusers. + + Args: + model: The model to disable SP for. + """ + from vllm_omni.diffusion.distributed.sp_plan import get_sp_plan_from_model + + plan = get_sp_plan_from_model(model) + if plan is not None: + remove_sequence_parallel(model, plan) diff --git a/vllm_omni/diffusion/layers/__init__.py b/vllm_omni/diffusion/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..9a644e1d935af30ce8365a969c00b5919ff54535 --- /dev/null +++ b/vllm_omni/diffusion/layers/adalayernorm.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from vllm_omni.diffusion.layers.custom_op import CustomOp + +logger = init_logger(__name__) + + +class AdaLayerNorm(CustomOp): + """ + AdaLayerNorm: + out = layernorm(x) * (1 + scale) + shift + """ + + def __init__(self, hidden_size: int, elementwise_affine: bool = False, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + self.hidden_size = hidden_size + self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps) + + def preprocess( + self, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + # shift: b d, scale: b d, gate: b d + shift, scale, gate = mod_params.chunk(3, dim=-1) + + if index is not None: + # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + # So shift, scale, gate have shape [2*actual_batch, d] + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + + # index: [b, 
l] where b is actual batch size + # Expand to [b, l, 1] to match feature dimension + index_expanded = index.unsqueeze(-1) # [b, l, 1] + + # Expand chunks to [b, 1, d] then broadcast to [b, l, d] + shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] + shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + + # Use torch.where to select based on index + shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + + return shift_result, scale_result, gate_result + + def forward_cuda( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + return self.forward_native(x, mod_params, index) + + def forward_hip( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + return self.forward_native(x, mod_params, index) + + def forward_npu( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + shift_result, scale_result, gate_result = self.preprocess(mod_params, index) + + import torch_npu + + output = torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=(1 + scale_result), bias=shift_result, eps=self.eps + ) + + return output, gate_result + + def forward_xpu( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + return self.forward_native(x, mod_params, index) + + def forward_native( + self, + x: torch.Tensor, + mod_params: torch.Tensor, + index: torch.Tensor = None, + ) -> torch.Tensor: + shift_result, scale_result, gate_result = self.preprocess(mod_params, index) + + return self.layernorm(x) * (1 + scale_result) + shift_result, gate_result diff --git a/vllm_omni/diffusion/layers/custom_op.py b/vllm_omni/diffusion/layers/custom_op.py new file mode 100644 index 0000000000000000000000000000000000000000..321bcbf8ad7048e46a330b18967445031f6ab816 --- /dev/null +++ b/vllm_omni/diffusion/layers/custom_op.py @@ -0,0 +1,53 @@ +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from vllm_omni.platforms import current_omni_platform + + +class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. + """ + + def __init__(self) -> None: + super().__init__() + self._forward_method = self.dispatch_forward() + + def dispatch_forward(self) -> Callable: + if current_omni_platform.is_rocm(): + return self.forward_hip + elif current_omni_platform.is_cuda(): + return self.forward_cuda + elif current_omni_platform.is_npu(): + return self.forward_npu + elif current_omni_platform.is_xpu(): + return self.forward_xpu + else: + return self.forward_native + + def forward(self, *args, **kwargs) -> Any: + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + """PyTorch-native implementation of the forward method. + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. 
+ """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_npu(self, *args, **kwargs): + raise NotImplementedError + + def forward_xpu(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + # By default, we assume that HIP ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) diff --git a/vllm_omni/diffusion/layers/rope.py b/vllm_omni/diffusion/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..461c25652e479ac8e3089604727174a4e24f0fa7 --- /dev/null +++ b/vllm_omni/diffusion/layers/rope.py @@ -0,0 +1,159 @@ +from importlib.util import find_spec + +import torch +from einops import rearrange, repeat +from vllm.logger import init_logger + +from vllm_omni.diffusion.layers.custom_op import CustomOp + +logger = init_logger(__name__) + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_emb_mindiesd( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + half_head_dim: bool = True, # if true, size of sin and cos is (B, S, D/2), otherwise (B, S, D) +) -> torch.Tensor: + from mindiesd import rotary_position_embedding + + if cos.dim() == 3: + # (B, S, D/2) -> (S, D/2) + cos = cos[0] + sin = sin[0] + + if interleaved: + # if last dim of sin and cos is D/2, expand to (S, D) to adapt to mindiesd operators + if half_head_dim: + seqlen = cos.shape[0] + sin = sin.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + cos = cos.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", head_first=False, fused=True) + else: + if half_head_dim: + seqlen = cos.shape[0] + sin = sin.unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + cos = cos.unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True) + + +class RotaryEmbedding(CustomOp): + """ + rotary positional embedding. + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). 
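+
+    Expected shapes (a brief sketch, matching `apply_rotary_emb_torch` above):
+        x:        (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)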
+ """ + + def __init__( + self, + is_neox_style: bool = False, + ) -> None: + super().__init__() + self.is_neox_style = is_neox_style + self.interleaved = not is_neox_style + self.apply_rotary_emb_flash_attn = None + if find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + self.apply_rotary_emb_flash_attn = apply_rotary + + def forward_cuda( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + if cos.dim() == 3: + # (B, S, D/2) -> (S, D/2) + cos = cos[0] + sin = sin[0] + + return apply_rotary_emb( + x, + cos, + sin, + interleaved=self.interleaved, + ) + + def forward_hip( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + if self.apply_rotary_emb_flash_attn is None: + return self.forward_cuda(x, cos, sin) + + if cos.dim() == 3: + # (B, S, D/2) -> (S, D/2) + cos = cos[0] + sin = sin[0] + + return self.apply_rotary_emb_flash_attn( + x, + cos, + sin, + interleaved=self.interleaved, + ) + + def forward_npu( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + if find_spec("mindiesd"): + return apply_rotary_emb_mindiesd(x, cos, sin, self.interleaved) + else: + return self.forward_native(x, cos, sin) + + def forward_xpu( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x, cos, sin) + + def forward_native( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + return apply_rotary_emb_torch( + x, + cos, + sin, + interleaved=self.interleaved, + ) diff --git a/vllm_omni/diffusion/lora/__init__.py b/vllm_omni/diffusion/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..353a2c6bee6b1c6cd77d95d63fa72415e79342cb --- /dev/null +++ b/vllm_omni/diffusion/lora/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager + +__all__ = ["DiffusionLoRAManager"] diff --git a/vllm_omni/diffusion/lora/layers/__init__.py b/vllm_omni/diffusion/lora/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab501f105f35ef80d41eb6be9044f2e9d06793c3 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .base_linear import DiffusionBaseLinearLayerWithLoRA +from .column_parallel_linear import ( + DiffusionColumnParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, +) +from .replicated_linear import DiffusionReplicatedLinearWithLoRA +from .row_parallel_linear import DiffusionRowParallelLinearWithLoRA + +__all__ = [ + "DiffusionBaseLinearLayerWithLoRA", + "DiffusionReplicatedLinearWithLoRA", + "DiffusionColumnParallelLinearWithLoRA", + "DiffusionMergedColumnParallelLinearWithLoRA", + "DiffusionRowParallelLinearWithLoRA", + "DiffusionQKVParallelLinearWithLoRA", + "DiffusionMergedQKVParallelLinearWithLoRA", +] diff --git a/vllm_omni/diffusion/lora/layers/base_linear.py b/vllm_omni/diffusion/lora/layers/base_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..fe32868d0831bb3b7fca82038100cfd363b4a827 --- /dev/null +++ 
b/vllm_omni/diffusion/lora/layers/base_linear.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA + + +class DiffusionBaseLinearLayerWithLoRA(BaseLinearLayerWithLoRA): + """ + Diffusion-specific base that overrides apply() to use direct torch matmul + instead of punica_wrapper. + + punica_wrapper is used to hold multiple LoRA slots and slices efficiently. + + This matches the semantics of PunicaWrapperGPU.add_lora_linear(): + - Shrink: buffer = (x @ lora_a.T) + - Expand: y += buffer @ lora_b.T + + All other functionality (weight management, TP slicing, forward logic) + is inherited from vLLM's BaseLinearLayerWithLoRA. + """ + + def create_lora_weights( + self, + max_loras: int, + lora_config, + model_config=None, + ) -> None: + super().create_lora_weights(max_loras, lora_config, model_config) + # Keep a direct reference for attribute forwarding: `base_layer` is a + # registered submodule (stored under `_modules`), so direct access via + # `object.__getattribute__` will not find it. We stash a ref in + # `__dict__` for robust lookups in `__getattr__`. + modules = object.__getattribute__(self, "_modules") + base_layer = modules.get("base_layer") or object.__getattribute__(self, "__dict__").get("base_layer") + object.__setattr__(self, "_diffusion_base_layer_ref", base_layer) + n_slices = getattr(self, "n_slices", 1) + self._diffusion_lora_active_slices = (False,) * int(n_slices) + + def reset_lora(self, index: int): + super().reset_lora(index) + n_slices = getattr(self, "n_slices", 1) + self._diffusion_lora_active_slices = (False,) * int(n_slices) + + def set_lora( + self, + index: int, + lora_a: torch.Tensor | list[torch.Tensor | None], + lora_b: torch.Tensor | list[torch.Tensor | None], + ): + super().set_lora(index, lora_a, lora_b) # type: ignore[arg-type] + + n_slices = getattr(self, "n_slices", 1) + if isinstance(lora_a, list) or isinstance(lora_b, list): + assert isinstance(lora_a, list) + assert isinstance(lora_b, list) + active_slices = [] + for a_i, b_i in zip(lora_a[:n_slices], lora_b[:n_slices]): + active_slices.append(a_i is not None and b_i is not None) + if len(active_slices) < n_slices: + active_slices.extend([False] * (n_slices - len(active_slices))) + self._diffusion_lora_active_slices = tuple(active_slices) + else: + # Single-slice layer. + self._diffusion_lora_active_slices = (True,) + + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + """ + override: Use simple matmul instead of punica_wrapper.add_lora_linear(). + + This matches the exact computation in PunicaWrapperGPU.add_lora_linear() + for the single-LoRA case. For packed projections (e.g. fused QKV), we + apply LoRA per-slice using `output_slices`. + """ + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + if not hasattr(self, "lora_a_stacked") or not hasattr(self, "lora_b_stacked"): + return output + if not self.lora_a_stacked or not self.lora_b_stacked: + return output + # Fast path: if no LoRA is active for this layer, skip matmuls. + active_slices = getattr(self, "_diffusion_lora_active_slices", None) + if active_slices is not None and not any(active_slices): + return output + + # In fully-sharded LoRA mode, vLLM uses an all-gather between shrink and + # expand for ColumnParallelLinear variants. This diffusion path doesn't + # implement that communication yet. 
+ if getattr(self, "lora_config", None) is not None: + if self.lora_config.fully_sharded_loras and self.tp_size > 1: + raise NotImplementedError( + "Diffusion LoRA apply() does not support fully_sharded_loras with tensor parallelism yet." + ) + + original_shape = output.shape + x_flat = x.reshape(-1, x.shape[-1]) + y_flat = output.reshape(-1, output.shape[-1]) + + output_slices = getattr(self, "output_slices", None) + if output_slices is None: + # Fallback: infer slice sizes from the allocated tensors. + output_slices = tuple(lora_b.shape[2] for lora_b in self.lora_b_stacked) + + if len(output_slices) != len(self.lora_a_stacked) or len(output_slices) != len(self.lora_b_stacked): + raise RuntimeError( + "LoRA slice metadata mismatch: " + f"output_slices={len(output_slices)}, " + f"lora_a_stacked={len(self.lora_a_stacked)}, " + f"lora_b_stacked={len(self.lora_b_stacked)}" + ) + + offset = 0 + for slice_idx, slice_size in enumerate(output_slices): + if active_slices is not None and slice_idx < len(active_slices) and not active_slices[slice_idx]: + offset += slice_size + continue + + A = self.lora_a_stacked[slice_idx][0, 0, :, :] # (rank, in_dim) + B = self.lora_b_stacked[slice_idx][0, 0, :, :] # (out_dim, rank) + + if A.numel() == 0 or B.numel() == 0: + offset += slice_size + continue + + # LoRA shrink & expand as in add_lora_linear(): + # buffer = (x @ A.T) + # y += buffer @ B.T + delta = (x_flat @ A.t()) @ B.t() + y_flat[:, offset : offset + slice_size] = y_flat[:, offset : offset + slice_size] + delta + offset += slice_size + + return y_flat.view(original_shape) + + def __getattr__(self, name: str): + # The diffusion model implementations may access attributes directly + # from linear layers (e.g. QKVParallelLinear.num_heads). vLLM's LoRA + # wrappers don't forward these attributes by default, so we delegate + # missing attribute lookups to the underlying base_layer. + try: + return super().__getattr__(name) + except AttributeError as exc: + base_layer = object.__getattribute__(self, "__dict__").get("_diffusion_base_layer_ref") + if base_layer is None: + base_layer = object.__getattribute__(self, "_modules").get("base_layer") + if base_layer is None: + raise exc + try: + return getattr(base_layer, name) + except AttributeError: + raise exc diff --git a/vllm_omni/diffusion/lora/layers/column_parallel_linear.py b/vllm_omni/diffusion/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..27ac94e61ed44ea17b088563a96426331daf2df9 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/column_parallel_linear.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, +) + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionColumnParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + ColumnParallelLinearWithLoRA, +): + """ + Diffusion ColumnParallelLinear with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionMergedColumnParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + MergedColumnParallelLinearWithLoRA, +): + """ + Diffusion MergedColumnParallelLinear (gate_up_proj) with LoRA. 
+ Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionQKVParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + QKVParallelLinearWithLoRA, +): + """ + Diffusion QKVParallelLinear with single LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionMergedQKVParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + MergedQKVParallelLinearWithLoRA, +): + """ + Diffusion MergedQKVParallelLinear (to_qkv) with 3 LoRAs. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/layers/replicated_linear.py b/vllm_omni/diffusion/lora/layers/replicated_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e6574a04f693185b80e5c560f8c5a45f50054cd0 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/replicated_linear.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionReplicatedLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + ReplicatedLinearWithLoRA, +): + """ + Diffusion ReplicatedLinear with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/layers/row_parallel_linear.py b/vllm_omni/diffusion/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2119092130bb93886eecdbeb832c8c5d63cd10 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/row_parallel_linear.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.row_parallel_linear import RowParallelLinearWithLoRA + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionRowParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + RowParallelLinearWithLoRA, +): + """ + Diffusion RowParallelLinear with LoRA. 
+ Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/manager.py b/vllm_omni/diffusion/lora/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7fad1b9e758daf81c114cf1982cbd5d32742a991 --- /dev/null +++ b/vllm_omni/diffusion/lora/manager.py @@ -0,0 +1,631 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from collections import OrderedDict + +import torch +import torch.nn as nn +from vllm.logger import init_logger +from vllm.lora.layers import BaseLayerWithLoRA +from vllm.lora.lora_model import LoRAModel +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import ( + get_adapter_absolute_path, + get_supported_lora_modules, + replace_submodule, +) +from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear + +from vllm_omni.config.lora import LoRAConfig +from vllm_omni.diffusion.lora.utils import ( + _expand_expected_modules_for_packed_layers, + _match_target_modules, + from_layer_diffusion, +) +from vllm_omni.lora.utils import stable_lora_int_id + +logger = init_logger(__name__) + + +class DiffusionLoRAManager: + """Manager for LoRA adapters in diffusion models. + + Reuses vLLM's LoRA infrastructure, adapted for diffusion pipelines. + Uses LRU cache management similar to LRUCacheLoRAModelManager. + """ + + def __init__( + self, + pipeline: nn.Module, + device: torch.device, + dtype: torch.dtype, + max_cached_adapters: int = 1, + lora_path: str | None = None, + lora_scale: float = 1.0, + ): + """ + Initialize the DiffusionLoRAManager. + + Args: + max_cached_adapters: Maximum number of LoRA adapters to keep in the + CPU-side cache (LRU). This mirrors vLLM's `max_cpu_loras` and is + exposed to users via `OmniDiffusionConfig.max_cpu_loras`. + """ + self.pipeline = pipeline + self.device = device + self.dtype = dtype + + # Cache supported/expected module suffixes once, before any layer + # replacement happens. After LoRA layers are injected, the original + # LinearBase layers become submodules named "*.base_layer", and calling + # vLLM's get_supported_lora_modules() again would incorrectly yield + # "base_layer" instead of the real target module suffixes. + self._supported_lora_modules = self._compute_supported_lora_modules() + self._packed_modules_mapping = self._compute_packed_modules_mapping() + self._expected_lora_modules = _expand_expected_modules_for_packed_layers( + self._supported_lora_modules, + self._packed_modules_mapping, + ) + + # LRU-style cache management + self.max_cached_adapters = max_cached_adapters # max_cpu_loras + self._registered_adapters: dict[int, LoRAModel] = {} # adapter_id -> LoRAModel + self._active_adapter_id: int | None = None + self._adapter_scales: dict[int, float] = {} # adapter_id -> external scale + + # LRU cache tracking (adapter_id -> last_used_time) + self._adapter_access_order: OrderedDict[int, float] = OrderedDict() + # Pinned adapters are not evicted + self._pinned_adapters: set[int] = set() + + # track replaced modules + # key: full module name (component.module.path); value: LoRA layer + self._lora_modules: dict[str, BaseLayerWithLoRA] = {} + # Track the maximum LoRA rank we've allocated buffers for. 
+ self._max_lora_rank: int = 0 + + logger.info( + "Initializing DiffusionLoRAManager: device=%s, dtype=%s, max_cached_adapters=%d, static_lora_path=%s", + device, + dtype, + max_cached_adapters, + lora_path, + ) + + if lora_path is not None: + logger.info("Loading LoRA during initialization from %s with scale %.2f", lora_path, lora_scale) + init_request = LoRARequest( + lora_name="static", + lora_int_id=stable_lora_int_id(lora_path), + lora_path=lora_path, + ) + self.set_active_adapter(init_request, lora_scale) + + def _compute_supported_lora_modules(self) -> set[str]: + """Compute supported LoRA module suffixes for this pipeline. + + vLLM's get_supported_lora_modules() returns suffixes for LinearBase + modules. After this manager replaces layers with BaseLayerWithLoRA + wrappers, those LinearBase modules become nested under ".base_layer", + which would cause get_supported_lora_modules() to return "base_layer". + To make adapter loading stable across multiple adapters, we also accept + suffixes from existing BaseLayerWithLoRA wrappers and drop "base_layer" + when appropriate. + """ + supported = set(get_supported_lora_modules(self.pipeline)) + + has_lora_wrappers = False + for name, module in self.pipeline.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + has_lora_wrappers = True + supported.add(name.split(".")[-1]) + + if has_lora_wrappers: + supported.discard("base_layer") + + return supported + + def _compute_packed_modules_mapping(self) -> dict[str, list[str]]: + """Collect packed->sublayer mappings from the diffusion model. + + vLLM models declare `packed_modules_mapping` on the model class. For + diffusion pipelines, we attach the same mapping on the transformer + module(s) that implement packed (fused) projections, so LoRA loading can + accept checkpoints trained against the logical sub-projections. + """ + mapping: dict[str, list[str]] = {} + for module in self.pipeline.modules(): + packed = getattr(module, "packed_modules_mapping", None) + if not isinstance(packed, dict): + continue + for packed_name, sub_names in packed.items(): + if not isinstance(packed_name, str) or not packed_name: + continue + if not isinstance(sub_names, (list, tuple)) or not all(isinstance(s, str) for s in sub_names): + continue + sub_names_list = list(sub_names) + if not sub_names_list: + continue + + existing = mapping.get(packed_name) + if existing is None: + mapping[packed_name] = sub_names_list + elif existing != sub_names_list: + logger.warning( + "Conflicting packed_modules_mapping for %s: %s vs %s; using %s", + packed_name, + existing, + sub_names_list, + existing, + ) + + return mapping + + def _get_packed_sublayer_suffixes(self, packed_module_suffix: str, n_slices: int) -> list[str] | None: + sub_suffixes = self._packed_modules_mapping.get(packed_module_suffix) + if not sub_suffixes: + return None + if len(sub_suffixes) != n_slices: + logger.warning( + "packed_modules_mapping[%s] has %d slices but layer expects %d; skipping sublayer lookup", + packed_module_suffix, + len(sub_suffixes), + n_slices, + ) + return None + return sub_suffixes + + def set_active_adapter(self, lora_request: LoRARequest | None, lora_scale: float = 1.0) -> None: + """Set the active LoRA adapter for the pipeline. + + Args: + lora_request: The LoRA request, or None to deactivate all adapters. + lora_scale: The external scale for the LoRA adapter. 
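# Editor's note — illustrative sketch, not part of the patch. It shows why the supported
# module suffixes must be computed before layer replacement: once a LoRA wrapper owns the
# original Linear as ".base_layer", suffix collection over Linear modules degenerates to
# "base_layer". The tiny model below is hypothetical and only mirrors that nesting.
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        self.base_layer = base


model = nn.ModuleDict({"to_q": nn.Linear(4, 4), "to_k": nn.Linear(4, 4)})
before = {name.split(".")[-1] for name, m in model.named_modules() if isinstance(m, nn.Linear)}
assert before == {"to_q", "to_k"}

# Replace the layers, then collect suffixes again: every Linear now lives under ".base_layer".
for key in list(model.keys()):
    model[key] = Wrapper(model[key])
after = {name.split(".")[-1] for name, m in model.named_modules() if isinstance(m, nn.Linear)}
assert after == {"base_layer"}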
+ """ + if lora_request is None: + logger.debug("No lora_request provided, deactivating all LoRA adapters") + self._deactivate_all_adapters() + return + + adapter_id = lora_request.lora_int_id + logger.debug( + "Setting active adapter: id=%d, name=%s, path=%s, scale=%.2f, cache_size=%d/%d", + adapter_id, + lora_request.lora_name, + lora_request.lora_path, + lora_scale, + len(self._registered_adapters), + self.max_cached_adapters, + ) + if adapter_id not in self._registered_adapters: + logger.info("Loading new adapter: id=%d, name=%s", adapter_id, lora_request.lora_name) + self.add_adapter(lora_request, lora_scale) + else: + logger.debug("Adapter %d already loaded, activating", adapter_id) + + # update access order + self._adapter_scales[adapter_id] = lora_scale + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + + self._activate_adapter(adapter_id) + + def _load_adapter( + self, + lora_request: LoRARequest, + ) -> tuple[LoRAModel, PEFTHelper]: + if not self._expected_lora_modules: + raise ValueError("No supported LoRA modules found in the diffusion pipeline.") + + logger.debug("Supported LoRA modules: %s", self._expected_lora_modules) + + lora_path = get_adapter_absolute_path(lora_request.lora_path) + logger.debug("Resolved LoRA path: %s", lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, + max_position_embeddings=None, # no need in diffusion + tensorizer_config_dict=lora_request.tensorizer_config_dict, + ) + + logger.info( + "Loaded PEFT config: r=%d, lora_alpha=%d, target_modules=%s", + peft_helper.r, + peft_helper.lora_alpha, + peft_helper.target_modules, + ) + + lora_model = LoRAModel.from_local_checkpoint( + lora_path, + expected_lora_modules=self._expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", # consistent w/ vllm's behavior + dtype=self.dtype, + model_vocab_size=None, + tensorizer_config_dict=lora_request.tensorizer_config_dict, + weights_mapper=None, + ) + + logger.info( + "Loaded LoRA model: id=%d, num_modules=%d, modules=%s", + lora_model.id, + len(lora_model.loras), + list(lora_model.loras.keys()), + ) + + for lora in lora_model.loras.values(): + lora.optimize() # ref: _create_merged_loras_inplace, internal scaling + + return lora_model, peft_helper + + def _get_packed_modules_list(self, module: nn.Module) -> list[str]: + """Return a packed_modules_list suitable for vLLM LoRA can_replace_layer(). + + Diffusion transformers frequently use packed projection layers like + QKVParallelLinear (fused QKV). vLLM's LoRA replacement logic relies on + `packed_modules_list` length to decide between single-slice vs packed + LoRA layer implementations. + """ + if isinstance(module, QKVParallelLinear): + # Treat diffusion QKV as a 3-slice packed projection by default. + return ["q", "k", "v"] + if isinstance(module, MergedColumnParallelLinear): + # 2-slice packed projection (e.g. fused MLP projections). 
+ return ["0", "1"] + return [] + + def _replace_layers_with_lora(self, peft_helper: PEFTHelper) -> None: + self._ensure_max_lora_rank(peft_helper.r) + + target_modules = getattr(peft_helper, "target_modules", None) + target_modules_list: list[str] | None = None + target_modules_pattern: str | None = None + if isinstance(target_modules, str) and target_modules: + target_modules_pattern = target_modules + elif isinstance(target_modules, list) and target_modules: + target_modules_list = target_modules + + def _matches_target(module_name: str) -> bool: + if target_modules_pattern is not None: + import regex as re + + return re.search(target_modules_pattern, module_name) is not None + if target_modules_list is None: + return True + return _match_target_modules(module_name, target_modules_list) + + # dummy lora config + lora_config = LoRAConfig( + max_lora_rank=self._max_lora_rank, + max_loras=1, + max_cpu_loras=self.max_cached_adapters, + lora_dtype=self.dtype, + fully_sharded_loras=False, + ) + + for component_name in ("transformer", "transformer_2", "dit"): + if not hasattr(self.pipeline, component_name): + continue + component = getattr(self.pipeline, component_name) + if not isinstance(component, nn.Module): + continue + + for module_name, module in component.named_modules(remove_duplicate=False): + # Don't recurse into already-replaced LoRA wrappers. Their + # original LinearBase lives under "base_layer", and replacing + # that again would nest LoRA wrappers and break execution. + if isinstance(module, BaseLayerWithLoRA) or "base_layer" in module_name.split("."): + continue + + full_module_name = f"{component_name}.{module_name}" + if full_module_name in self._lora_modules: + logger.debug("Layer %s already replaced, skipping", full_module_name) + continue + + packed_modules_list = self._get_packed_modules_list(module) + if target_modules_pattern is not None or target_modules_list is not None: + should_replace = _matches_target(full_module_name) + if not should_replace and len(packed_modules_list) > 1: + prefix, _, packed_suffix = full_module_name.rpartition(".") + sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, len(packed_modules_list)) + if sub_suffixes is not None: + for sub_suffix in sub_suffixes: + sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix + if _matches_target(sub_full_name): + should_replace = True + break + + if not should_replace: + continue + + lora_layer = from_layer_diffusion( + layer=module, + max_loras=1, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=None, + ) + + if lora_layer is not module and isinstance(lora_layer, BaseLayerWithLoRA): + replace_submodule(component, module_name, lora_layer) + self._lora_modules[full_module_name] = lora_layer + logger.debug("Replaced layer: %s -> %s", full_module_name, type(lora_layer).__name__) + + def _ensure_max_lora_rank(self, min_rank: int) -> None: + """Ensure LoRA buffers can accommodate adapters up to `min_rank`. + + We allocate per-layer LoRA buffers once when we first replace layers. + If a later adapter has a larger rank, we need to reinitialize those + buffers and re-apply the currently active adapter. 
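# Editor's note — illustrative sketch, not part of the patch. PEFT adapters ship
# target_modules either as a list of suffixes or as a single regex string, and the
# replacement loop above accepts both. The matcher below mirrors that behaviour with the
# stdlib re module (the real code uses the third-party "regex" package, same call shape).
import re


def matches(module_name: str, target_modules: list[str] | str) -> bool:
    if isinstance(target_modules, str):                 # regex form
        return re.search(target_modules, module_name) is not None
    return any(                                         # list-of-suffixes form
        re.match(rf".*\.{t}$", module_name) or t == module_name
        for t in target_modules
    )


name = "transformer.blocks.0.attn.to_q"
assert matches(name, ["to_q", "to_v"])
assert matches(name, r"blocks\.\d+\.attn")
assert not matches(name, ["to_k"])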
+ """ + if min_rank <= self._max_lora_rank: + return + + if min_rank <= 0: + raise ValueError(f"Invalid LoRA rank: {min_rank}") + + logger.info("Increasing max LoRA rank: %d -> %d", self._max_lora_rank, min_rank) + self._max_lora_rank = min_rank + + if not self._lora_modules: + return + + lora_config = LoRAConfig( + max_lora_rank=self._max_lora_rank, + max_loras=1, + max_cpu_loras=self.max_cached_adapters, + lora_dtype=self.dtype, + fully_sharded_loras=False, + ) + + # Recreate per-layer buffers with the new maximum rank. + for lora_layer in self._lora_modules.values(): + lora_layer.create_lora_weights(max_loras=1, lora_config=lora_config, model_config=None) + + # Re-apply active adapter if needed (buffers were reset). + if self._active_adapter_id is not None: + active_id = self._active_adapter_id + self._active_adapter_id = None + self._activate_adapter(active_id) + + def _get_lora_weights( + self, + lora_model: LoRAModel, + full_module_name: str, + ) -> LoRALayerWeights | PackedLoRALayerWeights | None: + """Best-effort lookup for LoRA weights by name. + + Tries: + - Full module name (e.g. transformer.blocks.0.attn.to_qkv) + - Relative name without the top-level component (e.g. blocks.0.attn.to_qkv) + - Suffix-only name (e.g. to_qkv) + """ + lora_weights = lora_model.get_lora(full_module_name) + if lora_weights is not None: + return lora_weights + + component_relative_name = full_module_name.split(".", 1)[-1] if "." in full_module_name else full_module_name + lora_weights = lora_model.get_lora(component_relative_name) + if lora_weights is not None: + return lora_weights + + module_suffix = full_module_name.split(".")[-1] + return lora_model.get_lora(module_suffix) + + def _activate_adapter(self, adapter_id: int) -> None: + if self._active_adapter_id == adapter_id: + logger.debug("Adapter %d already active, skipping", adapter_id) + return + + logger.info("Activating adapter: id=%d", adapter_id) + lora_model = self._registered_adapters[adapter_id] + + # activate weights in each LoRA layer + for full_module_name, lora_layer in self._lora_modules.items(): + lora_weights = self._get_lora_weights(lora_model, full_module_name) + + if lora_weights is None: + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + prefix, _, packed_suffix = full_module_name.rpartition(".") + sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, n_slices) + if sub_suffixes is None: + lora_layer.reset_lora(0) + continue + + sub_loras: list[LoRALayerWeights | None] = [] + any_found = False + for sub_suffix in sub_suffixes: + sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix + sub_lora = self._get_lora_weights(lora_model, sub_full_name) + if sub_lora is not None: + any_found = True + # Packed layers expect plain (non-packed) subloras. 
+ if isinstance(sub_lora, PackedLoRALayerWeights): + sub_lora = None + sub_loras.append(sub_lora if isinstance(sub_lora, LoRALayerWeights) else None) + + if not any_found: + lora_layer.reset_lora(0) + continue + + scale = self._adapter_scales.get(adapter_id, 1.0) + lora_a_list: list[torch.Tensor | None] = [] + lora_b_list: list[torch.Tensor | None] = [] + for sub_lora in sub_loras: + if sub_lora is None: + lora_a_list.append(None) + lora_b_list.append(None) + continue + lora_a_list.append(sub_lora.lora_a) + lora_b_list.append(sub_lora.lora_b * scale) + + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated packed LoRA for %s via submodules=%s (scale=%.2f)", + full_module_name, + sub_suffixes, + scale, + ) + else: + lora_layer.reset_lora(0) + continue + + scale = self._adapter_scales.get(adapter_id, 1.0) + + # Packed LoRA weights already provide per-slice tensors. + if isinstance(lora_weights, PackedLoRALayerWeights): + lora_a_list = lora_weights.lora_a + lora_b_list = [ + None if b is None else b * scale # type: ignore[operator] + for b in lora_weights.lora_b + ] + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated packed LoRA for %s (scale=%.2f)", + full_module_name, + scale, + ) + continue + + # Fused (non-packed) weights: if the layer is multi-slice, split B. + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + output_slices = getattr(lora_layer, "output_slices", None) + if output_slices is None: + lora_layer.reset_lora(0) + continue + + total = sum(output_slices) + if lora_weights.lora_b.shape[0] != total: + logger.warning( + "Skipping LoRA for %s due to shape mismatch: lora_b[0]=%d != sum(output_slices)=%d", + full_module_name, + lora_weights.lora_b.shape[0], + total, + ) + lora_layer.reset_lora(0) + continue + + b_splits = list(torch.split(lora_weights.lora_b, list(output_slices), dim=0)) + lora_a_list = [lora_weights.lora_a] * n_slices + lora_b_list = [b * scale for b in b_splits] + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated fused LoRA for packed layer %s (scale=%.2f)", + full_module_name, + scale, + ) + continue + + scaled_lora_b = lora_weights.lora_b * scale + lora_layer.set_lora(index=0, lora_a=lora_weights.lora_a, lora_b=scaled_lora_b) + logger.debug( + "Activated LoRA for %s: lora_a shape=%s, lora_b shape=%s, scale=%.2f", + full_module_name, + lora_weights.lora_a.shape, + lora_weights.lora_b.shape, + scale, + ) + + self._active_adapter_id = adapter_id + + def _deactivate_all_adapters(self) -> None: + logger.info("Deactivating all adapters: %d layers", len(self._lora_modules)) + for lora_layer in self._lora_modules.values(): + lora_layer.reset_lora(0) + self._active_adapter_id = None + logger.debug("All adapters deactivated") + + def _evict_if_needed(self) -> None: + while len(self._registered_adapters) > self.max_cached_adapters: + # Pick LRU among non-pinned adapters + evict_candidates = [aid for aid in self._adapter_access_order.keys() if aid not in self._pinned_adapters] + if not evict_candidates: + logger.warning( + "Cache full (%d) but all adapters are pinned; cannot evict. 
" + "Increase max_cached_adapters or unpin adapters.", + self.max_cached_adapters, + ) + break + + lru_adapter_id = evict_candidates[0] + logger.info( + "Evicting LRU adapter: id=%d (cache: %d/%d)", + lru_adapter_id, + len(self._registered_adapters), + self.max_cached_adapters, + ) + self.remove_adapter(lru_adapter_id) + + def add_adapter(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """ + Add a new adapter to the cache without activating it. + """ + adapter_id = lora_request.lora_int_id + + if adapter_id in self._registered_adapters: + logger.debug("Adapter %d already registered, skipping", adapter_id) + return False + + logger.info("Adding new adapter: id=%d, name=%s", adapter_id, lora_request.lora_name) + lora_model, peft_helper = self._load_adapter(lora_request) + self._registered_adapters[adapter_id] = lora_model + self._adapter_scales[adapter_id] = lora_scale + + self._replace_layers_with_lora(peft_helper) + + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + + # evict if cache full + self._evict_if_needed() + + logger.debug( + "Adapter %d added, cache size: %d/%d", adapter_id, len(self._registered_adapters), self.max_cached_adapters + ) + return True + + def remove_adapter(self, adapter_id: int) -> bool: + """ + Remove an adapter from the cache. + """ + if adapter_id not in self._registered_adapters: + logger.debug("Adapter %d not found, cannot remove", adapter_id) + return False + + logger.info("Removing adapter: id=%d", adapter_id) + if self._active_adapter_id == adapter_id: + self._deactivate_all_adapters() + + del self._registered_adapters[adapter_id] + self._adapter_scales.pop(adapter_id, None) + self._adapter_access_order.pop(adapter_id, None) + self._pinned_adapters.discard(adapter_id) + logger.debug( + "Adapter %d removed, cache size: %d/%d", + adapter_id, + len(self._registered_adapters), + self.max_cached_adapters, + ) + return True + + def list_adapters(self) -> list[int]: + """Return list of registered adapter ids.""" + return list(self._registered_adapters.keys()) + + def pin_adapter(self, adapter_id: int) -> bool: + """Mark an adapter as pinned so it will not be evicted.""" + if adapter_id not in self._registered_adapters: + logger.debug("Adapter %d not found, cannot pin", adapter_id) + return False + self._pinned_adapters.add(adapter_id) + # Touch access order so it is most recently used + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + logger.info("Pinned adapter id=%d (won't be evicted)", adapter_id) + return True diff --git a/vllm_omni/diffusion/lora/utils.py b/vllm_omni/diffusion/lora/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1baea34df43136ce8bad96634498b3cea06d70 --- /dev/null +++ b/vllm_omni/diffusion/lora/utils.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm_omni.config.lora import LoRAConfig +from vllm_omni.diffusion.lora.layers import ( + DiffusionColumnParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, + DiffusionReplicatedLinearWithLoRA, + DiffusionRowParallelLinearWithLoRA, +) + + +def _match_target_modules(module_name: str, target_modules: list[str]) -> bool: + """from 
vllm/lora/model_manager.py _match_target_modules, helper function""" + import regex as re + + return any( + re.match(rf".*\.{target_module}$", module_name) or target_module == module_name + for target_module in target_modules + ) + + +def _expand_expected_modules_for_packed_layers( + supported_modules: set[str], + packed_modules_mapping: dict[str, list[str]] | None, +) -> set[str]: + """Expand expected LoRA module suffixes for packed (fused) projections. + + Some diffusion models use packed projections like `to_qkv` or `w13`, while + LoRA checkpoints are typically saved against the logical sub-projections + (e.g. `to_q`/`to_k`/`to_v`, `w1`/`w3`). The packed layer name is present in + `supported_modules`, but the sublayer names are not. Expanding the set + ensures these sublayer keys are not dropped when loading a LoRA checkpoint. + + The packed→sublayer mapping is model-specific (see each diffusion model's + `packed_modules_mapping`) so new packed layers are added alongside the model + implementation rather than hard-coded in the LoRA framework. + """ + expanded = set(supported_modules) + if not packed_modules_mapping: + return expanded + + for packed_name, sub_names in packed_modules_mapping.items(): + if packed_name in supported_modules: + expanded.update(sub_names) + + return expanded + + +def from_layer_diffusion( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: list[str], + model_config: PretrainedConfig | None = None, +) -> nn.Module: + """ + Diffusion-specific layer replacement. similar to vLLM's `from_layer` + """ + diffusion_lora_classes = [ + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionColumnParallelLinearWithLoRA, + DiffusionRowParallelLinearWithLoRA, + DiffusionReplicatedLinearWithLoRA, + ] + + for lora_cls in diffusion_lora_classes: + if lora_cls.can_replace_layer( + source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + ): + instance = lora_cls(layer) # type: ignore[arg-type] + instance.create_lora_weights(max_loras, lora_config, model_config) + return instance + + return layer diff --git a/vllm_omni/diffusion/model_loader/__init__.py b/vllm_omni/diffusion/model_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..892954ce9ea621c805a8e9a637b71bb5f8fcff3e --- /dev/null +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import glob +import os +import time +from collections.abc import Generator, Iterable +from pathlib import Path +from typing import cast + +import torch +from torch import nn +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + maybe_download_from_modelscope, + safetensors_weights_iterator, +) +from vllm.utils.torch_utils import set_default_torch_dtype + +from vllm_omni.diffusion.data 
import OmniDiffusionConfig +from vllm_omni.diffusion.registry import initialize_model + +logger = init_logger(__name__) + + +MODEL_INDEX = "model_index.json" +DIFFUSION_MODEL_WEIGHTS_INDEX = "diffusion_pytorch_model.safetensors.index.json" + + +class DiffusersPipelineLoader: + """Model loader that can load diffusers pipeline components from disk.""" + + # default number of thread when enable multithread weight loading + DEFAULT_NUM_THREADS = 8 + + @dataclasses.dataclass + class ComponentSource: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + subfolder: str | None + """The subfolder inside the model repo.""" + + revision: str | None + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: list[str] | None = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + # TODO(Isotr0py): Enable multithreaded weight loading + # extra_config = load_config.model_loader_extra_config + # allowed_keys = {"enable_multithread_load", "num_threads"} + # unexpected_keys = set(extra_config.keys()) - allowed_keys + + # if unexpected_keys: + # raise ValueError( + # f"Unexpected extra config keys for load format {load_config.load_format}: {unexpected_keys}" + # ) + + def _prepare_weights( + self, + model_name_or_path: Path, + subfolder: str | None, + revision: str | None, + fall_back_to_pt: bool, + allow_patterns_overrides: list[str] | None, + ) -> tuple[str, list[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = maybe_download_from_modelscope(model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = DIFFUSION_MODEL_WEIGHTS_INDEX + index_file_with_subfolder = f"{subfolder}/{index_file}" if subfolder else index_file + + # only hf is supported currently + if load_format == "auto": + load_format = "hf" + + # Some quantized models use .pt files for storing the weights. + if load_format == "hf": + allow_patterns = ["*.safetensors", "*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if subfolder is not None: + allow_patterns = [f"{subfolder}/{pattern}" for pattern in allow_patterns] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + # Decide by actual files rather than pattern name (patterns may include subfolders). + use_safetensors = any(f.endswith(".safetensors") for f in hf_weights_files) + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. 
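# Editor's note — illustrative sketch, not part of the patch. It gives a simplified
# picture of why the safetensors index file is consulted: when a folder holds both
# sharded files and a consolidated file, keeping only the shards named in the index's
# weight_map avoids loading every weight twice. File names below are hypothetical.
import json


def keep_indexed_shards(files: list[str], index_json: str) -> list[str]:
    weight_map = json.loads(index_json)["weight_map"]
    indexed = set(weight_map.values())
    return [f for f in files if f in indexed]


index_json = json.dumps(
    {"weight_map": {"w.a": "model-00001-of-00002.safetensors",
                    "w.b": "model-00002-of-00002.safetensors"}}
)
files = ["model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
         "consolidated.safetensors"]
assert keep_indexed_shards(files, index_json) == files[:2]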
+ # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file_with_subfolder, + self.load_config.download_dir, + revision, + ) + # Some diffusers pipelines keep component weights under a + # subfolder (e.g. "transformer/") and the corresponding index file + # uses filenames relative to that subfolder. vLLM's + # `filter_duplicate_safetensors_files` expects weight_map entries + # to be relative to the `hf_folder` we pass in, so we point it to + # the component subfolder to avoid filtering out all shards. + filter_folder = os.path.join(hf_folder, subfolder) if subfolder is not None else hf_folder + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, + filter_folder, + index_file, + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator(self, source: "ComponentSource") -> Generator[tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, + source.subfolder, + source.revision, + source.fall_back_to_pt, + source.allow_patterns_overrides, + ) + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.safetensors_load_strategy, + ) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. + return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + + def get_all_weights( + self, + model: nn.Module, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + sources = cast( + Iterable[DiffusersPipelineLoader.ComponentSource], + getattr(model, "weights_sources", ()), + ) + for source in sources: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights( + model_name_or_path=model_config.model, + subfolder=None, + revision=model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None, + ) + + def load_model(self, od_config: OmniDiffusionConfig, load_device: str) -> nn.Module: + """Load a model with the given configurations.""" + target_device = torch.device(load_device) + with set_default_torch_dtype(od_config.dtype): + with target_device: + model = initialize_model(od_config) + + logger.debug("Loading weights on %s ...", load_device) + # Quantization does not happen in `load_weights` but after it + self.load_weights(model) + return model.eval() + + def load_weights(self, model: nn.Module) -> None: + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(self.get_all_weights(model)) + + self.counter_after_loading_weights = time.perf_counter() + logger.info_once( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - self.counter_before_loading_weights, + ) + # TODO(Isotr0py): Enable weights loading check after decoupling + # all components' weights loading (AutoModel.from_pretrained etc). + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. 
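# Editor's note — illustrative sketch, not part of the patch. The loader streams
# (name, tensor) pairs per component and prepends each ComponentSource.prefix, so one
# model.load_weights() call can consume several diffusers subfolders (e.g. "transformer/",
# "vae/"). The generator below fakes the safetensors iterator with in-memory tensors.
from collections.abc import Generator

import torch


def component_weights(
    prefix: str, weights: dict[str, torch.Tensor]
) -> Generator[tuple[str, torch.Tensor], None, None]:
    for name, tensor in weights.items():
        yield prefix + name, tensor


def all_weights(sources: list[tuple[str, dict[str, torch.Tensor]]]):
    for prefix, weights in sources:
        yield from component_weights(prefix, weights)


sources = [
    ("transformer.", {"blocks.0.attn.to_q.weight": torch.zeros(4, 4)}),
    ("vae.", {"encoder.conv_in.weight": torch.zeros(3, 3)}),
]
names = [name for name, _ in all_weights(sources)]
assert names == ["transformer.blocks.0.attn.to_q.weight", "vae.encoder.conv_in.weight"]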
+ if loaded_weights is not None: + _ = weights_to_load - loaded_weights + # if weights_not_loaded: + # raise ValueError( + # "Following weights were not initialized from " + # f"checkpoint: {weights_not_loaded}" + # ) diff --git a/vllm_omni/diffusion/models/__init__.py b/vllm_omni/diffusion/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7471a2da72f53e3928893f4c35bf19a1e80748 --- /dev/null +++ b/vllm_omni/diffusion/models/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Diffusion model implementations.""" diff --git a/vllm_omni/diffusion/models/bagel/__init__.py b/vllm_omni/diffusion/models/bagel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/models/bagel/autoencoder.py b/vllm_omni/diffusion/models/bagel/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0980f25cd19e518a9b9d18615a856e6855db5e83 --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/autoencoder.py @@ -0,0 +1,324 @@ +# Copyright (c) 2024 Black Forest Labs. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE. +# +# This modified file is released under the same license. + +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + downsample: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, 
affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = 
len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py new file mode 100644 index 
0000000000000000000000000000000000000000..6ee81a4fd47be4900df0e20cd0d183a5241e9d24 --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -0,0 +1,1177 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team. +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://github.com/huggingface/transformers/blob/main/LICENSE. + +import math +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn.attention.flex_attention import flex_attention +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2MLP, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, +) +from transformers.utils import ModelOutput +from vllm.transformers_utils.configs.bagel import BagelConfig +from vllm.vllm_flash_attn import flash_attn_varlen_func + +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +def patchify(imgs, p): + """ + imgs: (N, 3, H, W) or (3, H, W) + x: (N, L, patch_size**2 *3) or (L, patch_size**2 *3) + """ + is_batch = imgs.ndim == 4 + if not is_batch: + imgs = imgs.unsqueeze(0) + + # n: batch, c: channel, h: grid_h, p: patch_h, w: grid_w, q: patch_w + x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p) + # Permute to (n, grid_h, grid_w, c, patch_h, patch_w) to match Conv2d (c, h, w) flattening + x = torch.einsum("nchpwq->nhwcpq", x) + x = x.reshape(imgs.shape[0], -1, 3 * p**2) + + if not is_batch: + x = x.squeeze(0) + return x + + +class MLPconnector(nn.Module): + def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"): + super().__init__() + self.fc1 = nn.Linear(input_dim, output_dim) + if activation == "gelu": + self.act = nn.GELU() + elif activation == "gelu_pytorch_tanh": + self.act = nn.GELU(approximate="tanh") + else: + self.act = nn.ReLU() + self.fc2 = nn.Linear(output_dim, output_dim) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +torch._dynamo.config.cache_size_limit = 512 +torch._dynamo.config.accumulated_cache_size_limit = 4096 +flex_attention = torch.compile(flex_attention) + + +class Qwen2MoTConfig(Qwen2Config): + """Configuration for Qwen2MoT (Mixture of Tokens) model. + + This is fundamentally different from Qwen2, hence the distinct name. 
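# Editor's note — illustrative sketch, not part of the patch. It shows the core
# "Mixture of Tokens" idea used by the layers below: text tokens and VAE (generation)
# tokens share one packed sequence but are routed through separate projection weights,
# selected by index tensors. The two Linears stand in for q_proj vs. q_proj_moe_gen;
# names and sizes here are hypothetical.
import torch
import torch.nn as nn

hidden = 16
q_proj_und = nn.Linear(hidden, hidden)   # understanding (text) expert
q_proj_gen = nn.Linear(hidden, hidden)   # generation (VAE-token) expert

packed = torch.randn(6, hidden)          # packed sequence of 6 tokens
text_idx = torch.tensor([0, 1, 5])       # rows that are text tokens
vae_idx = torch.tensor([2, 3, 4])        # rows that are VAE latent tokens

out = packed.new_zeros(6, hidden)
out[text_idx] = q_proj_und(packed[text_idx])  # text tokens -> und weights
out[vae_idx] = q_proj_gen(packed[vae_idx])    # vae tokens  -> gen weights
assert out.shape == packed.shape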
+ """ + + model_type = "qwen2_mot" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + is_causal=True, + _attn_implementation="eager", + qk_norm=True, + layer_module="Qwen2MoTDecoderLayer", + freeze_und=False, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + max_window_layers=max_window_layers, + attention_dropout=attention_dropout, + is_causal=is_causal, + _attn_implementation=_attn_implementation, + **kwargs, + ) + self.qk_norm = qk_norm + self.layer_module = layer_module + + +class NaiveCache: + def __init__(self, num_layers): + self.key_cache = {k: None for k in range(num_layers)} + self.value_cache = {k: None for k in range(num_layers)} + + @property + def num_layers(self): + return len(self.key_cache) + + @property + def seq_lens(self): + if self.key_cache[0] is not None: + return self.key_cache[0].shape[0] + else: + return 0 + + +@dataclass +class BaseNavitOutputWithPast(ModelOutput): + packed_query_sequence: torch.FloatTensor = None + past_key_values: NaiveCache | None = None + + +class PackedAttentionMoT(Qwen2Attention): + def __init__(self, config, layer_idx: int | None = None): + super().__init__(config, layer_idx) + self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + + head_dim = self.head_dim + self.q_proj_moe_gen = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) + self.v_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) + self.o_proj_moe_gen = nn.Linear(config.num_attention_heads * head_dim, config.hidden_size, bias=False) + + self.rotary_op = RotaryEmbedding(is_neox_style=True) + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + 
packed_vae_token_indexes=None, + packed_text_indexes=None, + ): + if mode == "und": + packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) + packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = self.q_norm(packed_query_states) + packed_key_states = self.k_norm(packed_key_states) + elif mode == "gen": + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) + packed_query_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_heads * self.head_dim) + ) + packed_key_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) + packed_value_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) + + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + + packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) + packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + + packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) + packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + + packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) + packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + + packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + + packed_query_states = packed_query_states.to(torch.float32) + packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) + packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen( + packed_query_states[packed_vae_token_indexes] + ) + + packed_key_states = packed_key_states.to(torch.float32) + packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes]) + packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen( + packed_key_states[packed_vae_token_indexes] + ) + + cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + packed_query_states = self.rotary_op(packed_query_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + packed_key_states = self.rotary_op(packed_key_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + + packed_query_states = packed_query_states.to(torch.bfloat16) + packed_key_states = packed_key_states.to(torch.bfloat16) + packed_value_states = packed_value_states.to(torch.bfloat16) + + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + past_key_states = past_key_values.key_cache[self.layer_idx] + past_value_states = past_key_values.value_cache[self.layer_idx] + + seqlens = sum(query_lens) + sum(key_values_lens) + merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states[packed_query_indexes] = 
packed_key_states + merged_key_states[packed_key_value_indexes] = past_key_states + merged_value_states[packed_query_indexes] = packed_value_states + merged_value_states[packed_key_value_indexes] = past_value_states + key_values_lens = key_values_lens + query_lens + else: + merged_key_states = packed_key_states + merged_value_states = packed_value_states + key_values_lens = query_lens + + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + + packed_attn_output = flash_attn_varlen_func( + q=packed_query_states, + k=merged_key_states, + v=merged_value_states, + cu_seqlens_q=cu_seqlens_q.to(torch.int32), + cu_seqlens_k=cu_seqlens_k.to(torch.int32), + max_seqlen_q=max(query_lens).item(), + max_seqlen_k=max(key_values_lens).item(), + causal=is_causal, + ) + packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) + if mode == "und": + packed_attn_output = self.o_proj(packed_attn_output) + elif mode == "gen": + packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) + packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen( + packed_attn_output[packed_vae_token_indexes] + ) + + if update_past_key_values: + past_key_values.key_cache[self.layer_idx] = merged_key_states + past_key_values.value_cache[self.layer_idx] = merged_value_states + + return packed_attn_output, past_key_values + + +class Qwen2MoTDecoderLayer(nn.Module): + def __init__( + self, + config, + layer_idx: int | None = None, + attn_module: Qwen2Attention | None = PackedAttentionMoT, + ): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = attn_module(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.mlp_moe_gen = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + packed_query_sequence: torch.Tensor | None = None, + query_lens: torch.Tensor = None, + packed_query_position_embeddings: torch.Tensor = None, + packed_query_indexes: torch.Tensor = None, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + if packed_query_sequence is None: + packed_query_sequence = hidden_states + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.input_layernorm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.input_layernorm( + packed_query_sequence[packed_text_indexes] + ) + packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) + packed_query_sequence = packed_query_sequence_ + + # Self Attention + packed_query_sequence, past_key_values = self.self_attn( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + 
packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + packed_query_sequence = residual + packed_query_sequence + + # Fully Connected + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) + packed_query_sequence = self.mlp(packed_query_sequence) + elif mode == "gen": + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16) + packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to( + torch.bfloat16 + ) + + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) + packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence) + packed_query_sequence = packed_query_sequence_ + + packed_query_sequence = residual + packed_query_sequence + + return packed_query_sequence, past_key_values + + +class Qwen2MoTModel(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.use_moe = "Mo" in config.layer_module + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.use_moe: + self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + # create position embeddings to be shared across the decoder layers + cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0)) + cos = cos.squeeze(0) + sin = sin.squeeze(0) + packed_query_position_embeddings = (cos, sin) + + extra_inputs = {} + if self.use_moe: + extra_inputs.update(mode=mode) + if mode == "gen": + assert packed_vae_token_indexes is not None + assert packed_text_indexes is not None + extra_inputs.update( + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + for layer_idx, decoder_layer in enumerate(self.layers): + packed_query_sequence, past_key_values = decoder_layer( + hidden_states=packed_query_sequence, + encoder_hidden_states=None, + 
query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + **extra_inputs, + ) + + if self.use_moe: + if mode == "und": + packed_query_sequence = self.norm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes]) + packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) + packed_query_sequence = packed_query_sequence_ + else: + packed_query_sequence = self.norm(packed_query_sequence) + + return BaseNavitOutputWithPast( + packed_query_sequence=packed_query_sequence, + past_key_values=past_key_values, + ) + + +class Qwen2MoTForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2MoTModel(config) + self.vocab_size = config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + outputs = self.model( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + return outputs + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: 
output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side, hidden_size): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False) + self._init_weights() + + def _init_weights(self): + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + + def forward(self, position_ids): + return self.pos_embed[position_ids] + + +def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): + num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids + + +class Bagel(torch.nn.Module): + config_class = BagelConfig + base_model_prefix = "bagel" + + def __init__(self, language_model, vit_model, config: BagelConfig): + super().__init__() + self.language_model = language_model + self.hidden_size = config.llm_config.hidden_size + self.use_moe = "Mo" in config.llm_config.layer_module + self.num_heads = config.llm_config.num_attention_heads + + if config.visual_gen: + self.latent_patch_size = config.latent_patch_size + self.timestep_shift = config.timestep_shift + self.latent_downsample = config.vae_config.downsample * config.latent_patch_size + self.max_latent_size = config.max_latent_size + self.latent_channel = config.vae_config.z_channels + self.patch_latent_dim = 
self.latent_patch_size**2 * self.latent_channel + self.time_embedder = TimestepEmbedder(self.hidden_size) + self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) + self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) + self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) + + if config.visual_und: + self.vit_model = vit_model + self.vit_patch_size = config.vit_config.patch_size + self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side + self.vit_hidden_size = config.vit_config.hidden_size + self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act) + self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size) + + self.get_flattened_position_ids = get_flattened_position_ids_extrapolate + + self.config = config + self._init_weights() + + def _init_weights(self): + if self.config.visual_gen: + nn.init.constant_(self.llm2vae.weight, 0) + nn.init.constant_(self.llm2vae.bias, 0) + + def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids): + packed_text_ids = list() + packed_text_position_ids = list() + text_token_lens = list() + packed_text_indexes = list() + packed_key_value_indexes = list() + + curr = 0 + newlens, new_rope = list(), list() + for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + text_ids = tokenizer.encode(prompt) + text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]] + text_token_lens.append(len(text_ids)) + packed_text_ids.extend(text_ids) + packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids))) + packed_text_indexes.extend(range(curr, curr + len(text_ids))) + newlens.append(curr_kvlen + len(text_ids)) + new_rope.append(curr_position_id + len(text_ids)) + curr += len(text_ids) + + generation_input = { + "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int), + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + def forward_cache_update_text( + self, + past_key_values: NaiveCache, + packed_text_ids: torch.IntTensor, + packed_text_position_ids: torch.LongTensor, + text_token_lens: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward( + packed_query_sequence=packed_text_embedding, + query_lens=text_token_lens, + packed_query_position_ids=packed_text_position_ids, + packed_query_indexes=packed_text_indexes, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + key_values_lens=key_values_lens, + update_past_key_values=True, + is_causal=True, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, 
timestep=0): + patchified_vae_latent_shapes, packed_vae_position_ids = list(), list() + packed_vae_token_indexes = list() + packed_text_ids, packed_text_indexes = list(), list() + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + _curr = curr = 0 + vae_image_tensors = list() + newlens, new_rope = list(), list() + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids["start_of_image"]) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + image_tensor = transforms(image) + vae_image_tensors.append(image_tensor) + vae_position_ids = self.get_flattened_position_ids( + image_tensor.size(1), + image_tensor.size(2), + self.latent_downsample, + max_num_patches_per_side=self.max_latent_size, + ) + packed_vae_position_ids.append(vae_position_ids) + H, W = image_tensor.shape[1:] + h = H // self.latent_downsample + w = W // self.latent_downsample + patchified_vae_latent_shapes.append((h, w)) + + num_img_tokens = w * h + packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens)) + packed_indexes.extend(range(curr, curr + num_img_tokens)) + curr += num_img_tokens + _curr += num_img_tokens + + packed_text_ids.append(new_token_ids["end_of_image"]) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) + packed_seqlens.append(num_img_tokens + 2) + newlens.append(curr_kvlen + num_img_tokens + 2) + new_rope.append(curr_position_id + 1) + + image_sizes = [item.shape for item in vae_image_tensors] + max_image_size = [max(item) for item in list(zip(*image_sizes))] + padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size)) + for i, image_tensor in enumerate(vae_image_tensors): + padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = image_tensor + + generation_input = { + "padded_images": padded_images, + "patchified_vae_latent_shapes": patchified_vae_latent_shapes, + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), + "packed_timesteps": torch.tensor([timestep]), + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + def forward_cache_update_vae( + self, + vae_model, + past_key_values: NaiveCache, + padded_images: torch.Tensor, + patchified_vae_latent_shapes: list, + packed_vae_position_ids: torch.LongTensor, + packed_timesteps: torch.Tensor, + packed_vae_token_indexes: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.Tensor, + ): + packed_text_embedding = 
self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + padded_latent = vae_model.encode(padded_images) + + p = self.latent_patch_size + packed_latent = list() + for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes): + latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p) + latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) + packed_latent.append(latent) + packed_latent = torch.cat(packed_latent, dim=0) + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(packed_timesteps) + packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed + if packed_latent.dtype != packed_sequence.dtype: + packed_latent = packed_latent.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = packed_latent + + extra_inputs = {} + if self.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes, + } + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=True, + is_causal=False, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids): + packed_vit_token_indexes = list() + vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list() + packed_text_ids, packed_text_indexes = list(), list() + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + _curr = curr = 0 + newlens, new_rope = list(), list() + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids["start_of_image"]) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + image_tensor = transforms(image) + vit_position_ids = self.get_flattened_position_ids( + image_tensor.size(1), + image_tensor.size(2), + self.vit_patch_size, + max_num_patches_per_side=self.vit_max_num_patch_per_side, + ) + vit_tokens = patchify(image_tensor, self.vit_patch_size) + packed_vit_tokens.append(vit_tokens) + num_img_tokens = vit_tokens.shape[0] + packed_vit_position_ids.append(vit_position_ids) + vit_token_seqlens.append(num_img_tokens) + packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens)) + packed_indexes.extend(range(curr, curr + num_img_tokens)) + curr += num_img_tokens + _curr += num_img_tokens + + packed_text_ids.append(new_token_ids["end_of_image"]) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) + packed_seqlens.append(num_img_tokens + 2) + newlens.append(curr_kvlen + num_img_tokens + 2) + new_rope.append(curr_position_id + 1) + + generation_input = { + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + 
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int), + "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0), + "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0), + "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + def forward_cache_update_vit( + self, + past_key_values: NaiveCache, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_vit_tokens: torch.Tensor, + packed_vit_token_indexes: torch.LongTensor, + packed_vit_position_ids: torch.LongTensor, + vit_token_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_indexes: torch.LongTensor, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) + cu_seqlens = cu_seqlens.to(torch.int32) + max_seqlen = torch.max(vit_token_seqlens).item() + packed_vit_token_embed = self.vit_model( + packed_pixel_values=packed_vit_tokens, + packed_flattened_position_ids=packed_vit_position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + packed_vit_token_embed = self.connector(packed_vit_token_embed) + pos_emb = self.vit_pos_embed(packed_vit_position_ids) + packed_vit_token_embed = packed_vit_token_embed + pos_emb + if packed_vit_token_embed.dtype != packed_sequence.dtype: + packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype) + packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + key_values_lens=key_values_lens, + update_past_key_values=True, + is_causal=False, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None): + packed_text_ids, packed_text_indexes = list(), list() + packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() + packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + query_curr = curr = 0 + for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids["start_of_image"]) + packed_text_indexes.append(query_curr) + + packed_indexes.append(curr) + curr += 1 + query_curr 
+= 1 + + vae_position_ids = self.get_flattened_position_ids( + H, W, self.latent_downsample, max_num_patches_per_side=self.max_latent_size + ) + packed_vae_position_ids.append(vae_position_ids) + + h, w = H // self.latent_downsample, W // self.latent_downsample + num_image_tokens = h * w + + packed_init_noises.append(torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size**2)) + packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens)) + packed_seqlens.append(num_image_tokens + 2) + + packed_indexes.extend(range(curr, curr + num_image_tokens)) + curr += num_image_tokens + query_curr += num_image_tokens + + packed_text_ids.append(new_token_ids["end_of_image"]) + packed_text_indexes.append(query_curr) + + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) + + # Construct Output + generation_input = { + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_init_noises": torch.cat(packed_init_noises, dim=0), + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + + return generation_input + + def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids): + return self.prepare_input(curr_kvlens, curr_rope, image_sizes, new_token_ids) + + def generate_image( + self, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_init_noises: torch.Tensor, + packed_vae_position_ids: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_indexes: torch.LongTensor, + past_key_values: NaiveCache, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.LongTensor, + num_timesteps: int = 24, + timestep_shift: float = 1.0, + ): + model_pred_cache_dic, model_pred_current = None, None + model_pred_text_cache_dic, model_pred_text_current = None, None + model_pred_img_cache_dic, model_pred_img_current = None, None + + x_t = packed_init_noises + + timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) + timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) + dts = timesteps[:-1] - timesteps[1:] + timesteps = timesteps[:-1] + + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + v_t = self._forward_flow( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_position_ids=packed_position_ids, + packed_indexes=packed_indexes, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + # cache + model_pred_cache_dic=model_pred_cache_dic, + model_pred_current=model_pred_current, + 
model_pred_text_cache_dic=model_pred_text_cache_dic, + model_pred_text_current=model_pred_text_current, + model_pred_img_cache_dic=model_pred_img_cache_dic, + model_pred_img_current=model_pred_img_current, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + def _forward_flow( + self, + x_t: torch.Tensor, + timestep: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + key_values_lens: torch.IntTensor, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + # cache + model_pred_cache_dic: dict[str, Any] | None = None, + model_pred_current: int | None = None, + model_pred_text_cache_dic: dict[str, Any] | None = None, + model_pred_text_current: int | None = None, + model_pred_img_cache_dic: dict[str, Any] | None = None, + model_pred_img_current: int | None = None, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(timestep) + x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed + if x_t.dtype != packed_sequence.dtype: + x_t = x_t.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = x_t + + extra_inputs = {} + if self.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes, + } + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] + + return v_t diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb9f1f5c3f451a35eff40c5fc15929aa2b123a5 --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +BagelPipeline implementation for vLLM-Omni. 
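+It wires the ported Bagel MoT language model, the SigLIP ViT encoder and the Bagel
+AutoEncoder VAE into the vllm-omni diffusion engine interface for text-to-image and
+image-conditioned generation.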
+""" + +from __future__ import annotations + +import json +import os +from collections.abc import Iterable +from dataclasses import dataclass +from math import isqrt + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn +from transformers import AutoTokenizer, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.transformers_utils.configs.bagel import BagelConfig + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .bagel_transformer import Bagel, NaiveCache, Qwen2MoTConfig, Qwen2MoTForCausalLM + +logger = init_logger(__name__) + + +@dataclass +class BagelGenParams: + num_timesteps: int = 50 + timestep_shift: float = 1.0 + + +def add_special_tokens(tokenizer): + all_special_tokens = [] + for k, v in tokenizer.special_tokens_map.items(): + if isinstance(v, str): + all_special_tokens.append(v) + elif isinstance(v, list): + all_special_tokens += v + + new_tokens = [] + + if "<|im_start|>" not in all_special_tokens: + new_tokens.append("<|im_start|>") + + if "<|im_end|>" not in all_special_tokens: + new_tokens.append("<|im_end|>") + + if "<|vision_start|>" not in all_special_tokens: + new_tokens.append("<|vision_start|>") + + if "<|vision_end|>" not in all_special_tokens: + new_tokens.append("<|vision_end|>") + + num_new_tokens = tokenizer.add_tokens(new_tokens) + bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>") + eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + start_of_image = tokenizer.convert_tokens_to_ids("<|vision_start|>") + end_of_image = tokenizer.convert_tokens_to_ids("<|vision_end|>") + + new_token_ids = dict( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + start_of_image=start_of_image, + end_of_image=end_of_image, + ) + + return tokenizer, new_token_ids, num_new_tokens + + +def get_bagel_post_process_func(od_config: OmniDiffusionConfig): + # BagelPipeline returns PIL.Image.Image directly. 
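+ # The callable below is therefore the identity: forward() already decodes the latent
+ # into a PIL image (_decode_image_from_latent) and wraps it in DiffusionOutput.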
+ def post_process_func(x): + return x + + return post_process_func + + +@dataclass +class _VaeCfg: + z_channels: int = 16 + downsample: int = 8 + + +@dataclass +class _VitCfg: + patch_size: int = 14 + hidden_size: int = 1152 + + +def default_ae_params() -> AutoEncoderParams: + return AutoEncoderParams( + resolution=256, + in_channels=3, + downsample=8, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + +class SiglipNaViTWrapper(nn.Module): + def __init__(self, vision_model): + super().__init__() + # If input is SiglipVisionModel, unwrap it to get SiglipVisionTransformer + if hasattr(vision_model, "vision_model"): + self.vision_model = vision_model.vision_model + else: + self.vision_model = vision_model + + # Configure weights for linear equivalent of patch embedding + self.patch_embed_weight = self.vision_model.embeddings.patch_embedding.weight + self.patch_embed_bias = self.vision_model.embeddings.patch_embedding.bias + + def forward(self, packed_pixel_values, packed_flattened_position_ids, cu_seqlens, max_seqlen): + w = self.patch_embed_weight.view(self.patch_embed_weight.shape[0], -1) + x = F.linear(packed_pixel_values, w, self.patch_embed_bias) + pos = self.vision_model.embeddings.position_embedding(packed_flattened_position_ids) + x = x + pos + hidden_states = x.unsqueeze(0) + seq_len = x.shape[0] + mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype) + cu_seqlens_list = cu_seqlens.tolist() + for i in range(len(cu_seqlens_list) - 1): + start = cu_seqlens_list[i] + end = cu_seqlens_list[i + 1] + mask[..., start:end, start:end] = 0.0 + + outputs = self.vision_model.encoder(inputs_embeds=hidden_states, attention_mask=mask) + return outputs.last_hidden_state.squeeze(0) + + +class BagelPipeline(nn.Module): + """Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine. + + This pipeline is self-contained and uses the ported Bagel core files. + """ + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + + model = od_config.model + local_files_only = os.path.exists(model) + if local_files_only: + model_path = model + else: + # Download everything required (ema.safetensors, ae.safetensors, tokenizer files, configs). + model_path = download_weights_from_hf_specific(model, od_config.revision, ["*"]) + + # Load Bagel top-level config for VAE settings. + cfg_path = os.path.join(model_path, "config.json") + with open(cfg_path, encoding="utf-8") as f: + bagel_cfg = json.load(f) + + vae_cfg_dict = bagel_cfg.get("vae_config") or {} + vae_cfg = _VaeCfg( + z_channels=int(vae_cfg_dict.get("z_channels", 16)), + downsample=int(vae_cfg_dict.get("downsample", 8)), + ) + + # LLM config: Bagel MoT requires explicitly setting layer_module + llm_cfg_path = os.path.join(model_path, "llm_config.json") + llm_config = Qwen2MoTConfig.from_json_file(llm_cfg_path) + llm_config.qk_norm = True + llm_config.tie_word_embeddings = False + # Allow overriding from vllm-omni config if user wants MoE/vanilla. + llm_config.layer_module = od_config.override_transformer_cls_name or "Qwen2MoTDecoderLayer" + + # Tokenizer and special tokens. + # Bagel uses a Qwen2 tokenizer variant; prefer trust_remote_code to get the + # correct tokenizer implementation from the checkpoint repo when available. 
+ self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + local_files_only=True, + trust_remote_code=True, + ) + + # Try finding vision_config or interpolate from top-level config + vit_cfg_dict = bagel_cfg.get("vit_config") or {} + vit_cfg = _VitCfg( + patch_size=int(vit_cfg_dict.get("patch_size", 14)), + hidden_size=int(vit_cfg_dict.get("hidden_size", 1152)), + ) + vit_config_path = os.path.join(model_path, "vit_config.json") + vit_conf = SiglipVisionConfig.from_json_file(vit_config_path) + self.vit_model = SiglipVisionModel(vit_conf) + self.image_processor = SiglipImageProcessor.from_pretrained(model_path, local_files_only=True) + + if self.vit_model: + self.vit_model = SiglipNaViTWrapper(self.vit_model) + vit_cfg.hidden_size = self.vit_model.vision_model.config.hidden_size + vit_cfg.patch_size = self.vit_model.vision_model.config.patch_size + + self.tokenizer, self.new_token_ids, _ = add_special_tokens(self.tokenizer) + + tok_len = len(self.tokenizer) + required_max_id = max(int(v) for v in self.new_token_ids.values()) + llm_config.vocab_size = max( + int(getattr(llm_config, "vocab_size", tok_len)), + int(tok_len), + int(required_max_id + 1), + ) + + self.language_model = Qwen2MoTForCausalLM(llm_config) + ae_params: AutoEncoderParams = default_ae_params() + self.vae = AutoEncoder(ae_params) + + self.bagel = Bagel( + language_model=self.language_model, + vit_model=self.vit_model, + config=BagelConfig( + llm_config=llm_config, + vae_config=vae_cfg, + vit_config=vit_cfg, + vit_max_num_patch_per_side=int(bagel_cfg.get("vit_max_num_patch_per_side", 70)), + connector_act=str(bagel_cfg.get("connector_act", "gelu_pytorch_tanh")), + interpolate_pos=bool(bagel_cfg.get("interpolate_pos", False)), + latent_patch_size=int(bagel_cfg.get("latent_patch_size", 2)), + max_latent_size=int(bagel_cfg.get("max_latent_size", 32)), + timestep_shift=float(bagel_cfg.get("timestep_shift", 1.0)), + ), + ) + + # Let vLLM loader download and stream all *.safetensors under model root. + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder=None, + revision=od_config.revision, + prefix="", + fall_back_to_pt=False, + ) + ] + + self.to(self.device) + + @staticmethod + def _decode_image_from_latent( + bagel: Bagel, vae: AutoEncoder, latent: torch.Tensor, image_shape: tuple[int, int] + ) -> Image.Image: + H, W = image_shape + h, w = H // bagel.latent_downsample, W // bagel.latent_downsample + p = bagel.latent_patch_size + c = bagel.latent_channel + latent = latent.reshape(1, h, w, p, p, c) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, c, h * p, w * p) + + # Cast to VAE dtype (e.g. bfloat16) as latents might remain float32 from generation loop + vae_dtype = next(vae.parameters()).dtype + latent = latent.to(vae_dtype) + + image = vae.decode(latent) + image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255 + return Image.fromarray(image.to(torch.uint8).cpu().numpy()) + + @torch.inference_mode() + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + # TODO: In online mode, sometimes it receives [{"prompts": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. 
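+ # Illustrative payload shapes handled by the isinstance()/.get() fallback below
+ # (examples only, not an exhaustive schema):
+ #   req.prompts == ["a photo of a cat"]
+ #   req.prompts == [{"prompt": "a photo of a cat", "multi_modal_data": {"image": img}}]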
+ first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(req.prompts[0], str) else (req.prompts[0].get("prompt") or "") + + max_hw = int(self.bagel.max_latent_size * self.bagel.latent_downsample) + if req.sampling_params.height is None and req.sampling_params.width is None: + height = width = max_hw + else: + height = int(req.sampling_params.height) if req.sampling_params.height is not None else max_hw + width = int(req.sampling_params.width) if req.sampling_params.width is not None else max_hw + if height > max_hw or width > max_hw: + raise ValueError( + f"Requested resolution {height}x{width} exceeds Bagel checkpoint limit " + f"{max_hw}x{max_hw} (max_latent_size={self.bagel.max_latent_size}, " + f"latent_downsample={self.bagel.latent_downsample})." + ) + image_shape = (height, width) + + # Map request params to Bagel gen params (defaults follow Bagel inferencer) + gen_params = BagelGenParams( + num_timesteps=int(req.sampling_params.num_inference_steps or 50), + timestep_shift=3.0, + ) + + gen_context = { + "kv_lens": [0], + "ropes": [0], + "past_key_values": NaiveCache(self.bagel.config.llm_config.num_hidden_layers), + } + + # Add text prompt (prefill) on gen context. + # [Omni] Check for injected KV Cache from remote transfer + injected_kv = req.sampling_params.past_key_values + if injected_kv is not None: + logger.info("Using injected KV Cache (direct)") + gen_context["past_key_values"] = injected_kv + + # User requested: kv_lens and ropes set to [gen_context["past_key_values"].key_cache[0].shape[0]] + # Assuming injected_kv is compatible and has key_cache[0] + seq_len = injected_kv.key_cache[0].shape[0] + gen_context["kv_lens"] = [seq_len] + gen_context["ropes"] = [seq_len] + + else: + image_input = ( + None if isinstance(first_prompt, str) else (first_prompt.get("multi_modal_data") or {}).get("image") + ) + if image_input and not isinstance(image_input, list): + image_input = [image_input] + if image_input: + image_input = [Image.open(image) if isinstance(image, str) else image for image in image_input] + + if image_input: + # If we have an image, we prefill with it + if self.image_processor and self.vae: + + def vit_transforms(img): + # SigLIP processor returns dict with pixel_values; we want the tensor + return self.image_processor(images=img, return_tensors="pt").pixel_values[0] + + def vae_transforms(img): + if img.mode != "RGB": + img = img.convert("RGB") + # Convert to [-1, 1] tensor (H, W, C) -> (C, H, W) + arr = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0 + return arr.permute(2, 0, 1) + + # 1. Update VAE + gen_input_vae, newlens_vae, new_rope_vae = self.bagel.prepare_vae_images( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + images=image_input, + transforms=vae_transforms, + new_token_ids=self.new_token_ids, + ) + + for k, v in gen_input_vae.items(): + if torch.is_tensor(v): + gen_input_vae[k] = v.to(self.device) + + # VAE needs bfloat16 to match model strings usually, specifically encode + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + gen_context["past_key_values"] = self.bagel.forward_cache_update_vae( + self.vae, gen_context["past_key_values"], **gen_input_vae + ) + gen_context["kv_lens"] = newlens_vae + gen_context["ropes"] = new_rope_vae + + # 2. 
Update ViT + gen_input_img, newlens_img, new_rope_img = self.bagel.prepare_vit_images( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + images=image_input, + transforms=vit_transforms, + new_token_ids=self.new_token_ids, + ) + + for k, v in gen_input_img.items(): + if torch.is_tensor(v): + gen_input_img[k] = v.to(self.device) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + gen_context["past_key_values"] = self.bagel.forward_cache_update_vit( + gen_context["past_key_values"], **gen_input_img + ) + gen_context["kv_lens"] = newlens_img + gen_context["ropes"] = new_rope_img + generation_input, newlens, new_rope = self.bagel.prepare_prompts( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + prompts=[prompt], + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, + ) + # Fail fast with a clear error instead of CUDA gather OOB. + max_tid = int(generation_input["packed_text_ids"].max().item()) + emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + if max_tid >= emb_n: + raise ValueError( + "Tokenizer/model vocab mismatch: max token id " + f"{max_tid} >= embed_tokens size {emb_n}. " + "This usually means you're not using the tokenizer shipped with the Bagel checkpoint, " + "or llm_config.vocab_size is smaller than the tokenizer vocab." + ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(self.device) + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + gen_context["past_key_values"] = self.bagel.forward_cache_update_text( + gen_context["past_key_values"], **generation_input + ) + gen_context["kv_lens"] = newlens + gen_context["ropes"] = new_rope + + if req.sampling_params.seed is not None: + torch.manual_seed(req.sampling_params.seed) + if self.device.type == "cuda": + torch.cuda.manual_seed(req.sampling_params.seed) + + # Prepare latent query and run flow + generation_input = self.bagel.prepare_vae_latent( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + image_sizes=[image_shape], + new_token_ids=self.new_token_ids, + ) + # Fail fast for special tokens used by the image path as well. + max_tid_img = int(generation_input["packed_text_ids"].max().item()) + emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + if max_tid_img >= emb_n: + raise ValueError( + "Tokenizer/model vocab mismatch (image path): max token id " + f"{max_tid_img} >= embed_tokens size {emb_n}. " + "This indicates the tokenizer token IDs do not match the checkpoint embeddings." + ) + # Position ids must be non-negative; negative ids can trigger CUDA gather OOB inside RoPE. + min_pid = int(generation_input["packed_position_ids"].min().item()) + if min_pid < 0: + raise ValueError(f"Invalid packed_position_ids: min={min_pid} (must be >= 0)") + # Latent position embedding bounds check: ids must be < max_latent_size^2. + max_lat_pid = int(generation_input["packed_vae_position_ids"].max().item()) + max_lat_pid_allowed = int(self.bagel.max_latent_size * self.bagel.max_latent_size) - 1 + if max_lat_pid > max_lat_pid_allowed: + raise ValueError( + "Invalid packed_vae_position_ids (latent position embedding OOB): " + f"max={max_lat_pid} > allowed_max={max_lat_pid_allowed}. " + f"Requested image_shape={image_shape}, max_latent_size={self.bagel.max_latent_size}." 
+ ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(self.device) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + latents = self.bagel.generate_image( + past_key_values=gen_context["past_key_values"], + num_timesteps=gen_params.num_timesteps, + timestep_shift=gen_params.timestep_shift, + **generation_input, + ) + + # Decode first sample + img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape) + return DiffusionOutput(output=img) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + state = self.state_dict() + allowed = set(state.keys()) + shapes = {k: tuple(v.shape) for k, v in state.items()} + + def _normalize_name(name: str) -> str: + # Common wrappers/prefixes in checkpoints. + for pfx in ("module.", "model."): + if name.startswith(pfx): + name = name[len(pfx) :] + # Common component renames across repos. + if name.startswith("vae_model."): + name = "vae." + name[len("vae_model.") :] + # Bagel `ae.safetensors` commonly stores AE weights without a top-level prefix. + # Map them into this pipeline's `vae.*` namespace. + if name.startswith("encoder.") or name.startswith("decoder."): + name = "vae." + name + return name + + def _iter_candidate_names(name: str) -> Iterable[str]: + """Yield candidate parameter names in this pipeline for a checkpoint key. + + The upstream Bagel repo typically stores Bagel-core layers (time_embedder, + latent_pos_embed, vae2llm, llm2vae, etc.) at the top-level of the model, + while this vllm-omni integration nests them under `self.bagel`. + """ + n = _normalize_name(name) + yield n + + # Map Bagel core layers from top-level -> `bagel.*` namespace. + for pfx in ("time_embedder.", "latent_pos_embed.", "vae2llm.", "llm2vae."): + if n.startswith(pfx): + yield "bagel." + n + break + + # Map connector and vit_pos_embed to `bagel.*` + for pfx in ("connector.", "vit_pos_embed."): + if n.startswith(pfx): + yield "bagel." + n + break + + if n.startswith("vit_model."): + yield "bagel." + n # matches self.bagel.vit_model + elif n.startswith("vision_model."): + yield "bagel.vit_model." + n + elif n.startswith("model.vision_model."): + yield "bagel.vit_model." + n[len("model.") :] + + def _filtered_weights(): + total = 0 + kept = 0 + shape_mismatch = 0 + for name, tensor in weights: + total += 1 + picked = None + for cand in _iter_candidate_names(name): + if cand in allowed: + # Only accept if tensor shape matches target param/buffer shape. + if tuple(tensor.shape) == shapes.get(cand): + picked = cand + break + else: + if cand.endswith("bagel.latent_pos_embed.pos_embed") and tensor.ndim == 2: + npos, hdim = tensor.shape + side = isqrt(int(npos)) + if side * side == int(npos) and hdim == int(self.bagel.hidden_size): + param = self.bagel.latent_pos_embed.pos_embed + # Resize in-place to keep the same Parameter object. + param.data = param.data.new_empty((npos, hdim)) + # Update model bookkeeping so position-id generation matches. 
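+ # Example (illustration only): a checkpoint pos_embed of shape (4096, hidden) gives
+ # side = isqrt(4096) = 64, so max_latent_size is bumped to 64 and the
+ # packed_vae_position_ids generated later stay within the resized table.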
+ self.bagel.max_latent_size = int(side) + if hasattr(self.bagel, "config"): + setattr(self.bagel.config, "max_latent_size", int(side)) + if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"): + self.bagel.latent_pos_embed.max_num_patch_per_side = int(side) + shapes[cand] = (npos, hdim) + picked = cand + break + # Handle flattened patch embedding for SigLIP + if cand.endswith("embeddings.patch_embedding.weight") and tensor.ndim == 2: + # Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P) + if shapes.get(cand) is not None: + target_shape = shapes[cand] + if tensor.numel() == torch.prod(torch.tensor(target_shape)): + # Reshape tensor to match target + tensor = tensor.view(target_shape) + picked = cand + break + + shape_mismatch += 1 + # Keep this quiet; shape mismatches are expected for ignored modules. + if picked is not None: + kept += 1 + yield picked, tensor + # else: ignore extra weights (e.g. connector/vision/und) + logger.info_once( + "BagelPipeline weight filter kept %d/%d tensors (shape mismatches seen: %d)", + kept, + total, + shape_mismatch, + ) + + loader = AutoWeightsLoader(self) + return loader.load_weights(_filtered_weights()) diff --git a/vllm_omni/diffusion/models/flux/__init__.py b/vllm_omni/diffusion/models/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cad3508a104026850ed41b068c429c25118320c2 --- /dev/null +++ b/vllm_omni/diffusion/models/flux/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FLUX.1-dev diffusion model components.""" + +from vllm_omni.diffusion.models.flux.flux_transformer import ( + FluxTransformer2DModel, +) +from vllm_omni.diffusion.models.flux.pipeline_flux import ( + FluxPipeline, + get_flux_post_process_func, +) + +__all__ = [ + "FluxPipeline", + "FluxTransformer2DModel", + "get_flux_post_process_func", +] diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..faf6d08d3a508f6df7da2f74e8d68ef343d604fa --- /dev/null +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -0,0 +1,638 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn.functional as F +from diffusers.models.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from diffusers.utils import is_torch_npu_available +from torch import nn +from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +logger = init_logger(__name__) + + +class ColumnParallelApproxGELU(nn.Module): + def __init__(self, 
dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." + + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), # placeholder for weight loading + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class FluxAttention(torch.nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=bias, + ) + + if not self.pre_only: + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + nn.Dropout(dropout), + ] + ) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + + self.add_kv_proj = QKVParallelLinear( + hidden_size=self.added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=added_proj_bias, + ) + + self.to_add_out = RowParallelLinear( + self.inner_dim, + query_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.to_qkv.num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.to_qkv.num_kv_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.to_qkv(hidden_states) + q_size = self.to_qkv.num_heads * self.head_dim + kv_size = 
self.to_qkv.num_kv_heads * self.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + + query = query.unflatten(-1, (self.to_qkv.num_heads, -1)) + key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_kv_proj.num_heads * self.head_dim + add_kv_size = self.add_kv_proj.num_kv_heads * self.head_dim + encoder_query, encoder_key, encoder_value = encoder_qkv.split( + [add_q_size, add_kv_size, add_kv_size], dim=-1 + ) + + encoder_query = encoder_query.unflatten(-1, (self.add_kv_proj.num_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.add_kv_proj.num_kv_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.add_kv_proj.num_kv_heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb # [S, D/2] + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + # For single-stream blocks, there's no to_out (RowParallelLinear) to handle the reduction + if get_tensor_model_parallel_world_size() > 1: + hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=-1) + return hidden_states + + +class FluxTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. 
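+ # Joint (dual-stream) attention: FluxAttention prepends the projected text tokens to the
+ # image tokens, attends over the concatenation, then splits the result back into
+ # (attn_output, context_attn_output) for the two residual streams handled below.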
+ attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class FluxPosEmbed(nn.Module): + # modified from 
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class FluxTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux. + + Args: + od_config (`OmniDiffusionConfig`): + The configuration for the model. + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `19`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `38`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `4096`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
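+
+    For reference (derived from the defaults above rather than a separate spec): the hidden
+    size used throughout the model is
+    `inner_dim = num_attention_heads * attention_head_dim = 24 * 128 = 3072`, and
+    `in_channels = 64` corresponds to 16 latent channels packed into 2x2 patches
+    (16 * 2 * 2 = 64), which matches `num_channels_latents = in_channels // 4` in the pipeline.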
+ """ + + # the small and frequently-repeated block(s) of a model + # -- typically a transformer layer + # used for torch compile optimizations + _repeated_blocks = ["FluxTransformerBlock"] + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 1, + in_channels: int = 64, + out_channels: int = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = True, + axes_dims_rope: tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + model_config = od_config.tf_model_config + num_layers = model_config.num_layers + self.parallel_config = od_config.parallel_config + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.guidance_embeds = guidance_embeds + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + img_ids: (`torch.Tensor`): + The position ids for image tokens. + txt_ids (`torch.Tensor`): + The position ids for text tokens. + guidance (`torch.Tensor`): + Guidance embeddings for guidance-distilled variant of the model. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + hidden_states = self.x_embedder(hidden_states) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(device=hidden_states.device, dtype=hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + if is_torch_npu_available(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self-attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross-attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + original_name = name + lookup_name = name + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in original_name: + continue + lookup_name = original_name.replace(weight_name, param_name) + param = params_dict[lookup_name] + weight_loader = param.weight_loader + 
weight_loader(param, loaded_weight, shard_id) + break + else: + if lookup_name not in params_dict and ".to_out.0." in lookup_name: + lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + param = params_dict[lookup_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(original_name) + loaded_params.add(lookup_name) + return loaded_params diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..b90aaa8ca4643c4ff12fcffbf0c278b44bec8f5b --- /dev/null +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -0,0 +1,741 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import json +import logging +import os +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.flux import FluxTransformer2DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +logger = logging.getLogger(__name__) + + +def get_flux_post_process_func( + od_config: OmniDiffusionConfig, +): + if od_config.output_type == "latent": + return lambda x: x + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's 
`set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxPipeline( + nn.Module, +): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = CLIPTextModel.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_encoder_2 = T5EncoderModel.from_pretrained( + model, subfolder="text_encoder_2", local_files_only=local_files_only + ) + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + self.transformer = FluxTransformer2DModel(od_config=od_config) + + self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_2 = T5TokenizerFast.from_pretrained( + model, subfolder="tokenizer_2", local_files_only=local_files_only + ) + + self.stage = None + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. + # This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. " + "Make sure to generate `pooled_prompt_embeds` from the same text encoder " + "that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. " + "Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder " + "that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + dtype: torch.dtype | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(self.device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + ): + prompt = [prompt] if isinstance(prompt, str) else 
prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. 
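+
+        Returns:
+            A 3-tuple `(prompt_embeds, pooled_prompt_embeds, text_ids)`, as produced by the code
+            below: the per-token T5 embeddings, the pooled CLIP embedding, and the zero-initialized
+            position ids for the text tokens.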
+ """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=self.device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def diffuse( + self, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + latents, + latent_image_ids, + text_ids, + negative_text_ids, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + ): + """Diffusion loop with optional image conditioning.""" + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension and place on same device/dtype as latents + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + self.transformer.do_true_cfg = do_true_cfg # used in teacache hook + # Forward pass for positive prompt (or unconditional if no CFG) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # Forward pass for negative prompt (CFG) + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + return latents + + def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool): + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False + + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." 
+ ) + return False + return True + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + """Forward pass for flux.""" + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + negative_text_ids = None + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + + # 5. Prepare timesteps + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + latents = self.diffuse( + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + latents, + latent_image_ids, + text_ids, + negative_text_ids, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + ) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/flux2_klein/__init__.py b/vllm_omni/diffusion/models/flux2_klein/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d477ab0a488339451ec9b96ec26ae6d62384c1c --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Flux2 klein diffusion model components.""" + +from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.flux2_klein.pipeline_flux2_klein import ( + Flux2KleinPipeline, + get_flux2_klein_post_process_func, +) + +__all__ = [ + "Flux2KleinPipeline", + "Flux2Transformer2DModel", + "get_flux2_klein_post_process_func", +] diff --git 
a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee10d2e0e4d098ac92c2a13eed5d8472f3f17611 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -0,0 +1,768 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterable +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.embeddings import ( + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +class Flux2SwiGLU(nn.Module): + """SwiGLU activation used by Flux2.""" + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 3.0, + inner_dim: int | None = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.linear_in = MergedColumnParallelLinear( + dim, + [inner_dim, inner_dim], + bias=bias, + return_bias=False, + ) + self.act_fn = Flux2SwiGLU() + self.linear_out = RowParallelLinear( + inner_dim, + dim_out, + bias=bias, + input_is_parallel=True, + return_bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + return self.linear_out(x) + + +class Flux2Attention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if 
out_dim is not None else heads + self.dropout = dropout + self.added_kv_proj_dim = added_kv_proj_dim + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=bias, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + nn.Dropout(dropout), + ] + ) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=added_proj_bias, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + self.to_add_out = RowParallelLinear( + self.inner_dim, + query_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.query_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.kv_num_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.query_num_heads, -1)) + key = key.unflatten(-1, (self.kv_num_heads, -1)) + value = value.unflatten(-1, (self.kv_num_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.add_query_num_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.add_kv_num_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.add_kv_num_heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + if encoder_hidden_states is not None: + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, 
+ ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + return hidden_states + + +class Flux2ParallelSelfAttention(nn.Module): + """ + Parallel attention block that fuses QKV projections with MLP input projections. + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + self.to_qkv_mlp_proj = ColumnParallelLinear( + self.query_dim, + self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=bias, + gather_output=True, + ) + self.mlp_act_fn = Flux2SwiGLU() + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = ColumnParallelLinear( + self.inner_dim + self.mlp_hidden_dim, + self.out_dim, + bias=out_bias, + gather_output=True, + ) + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, _ = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, + [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + query, key, value = qkv.chunk(3, dim=-1) + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + attn_output = self.attn(query, key, value, attn_metadata) + attn_output = attn_output.flatten(2, 3).to(query.dtype) + + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1) + hidden_states, _ = self.to_out(hidden_states) + return hidden_states + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + 
out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + split_hidden_states: bool = False, + text_seq_len: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output = gate_msa * attn_output + hidden_states = 
hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(len(self.axes_dim)): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) + + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) + return timesteps_emb + guidance_emb + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux 2. 
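+
+    Compared with `FluxTransformer2DModel` (a summary of the code below, not of an external spec):
+    the AdaLN modulation parameters are computed once per forward pass from the timestep/guidance
+    embedding and shared across all blocks, and the single-stream blocks fuse the QKV and MLP
+    input projections into a single matmul (`Flux2ParallelSelfAttention`).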
+ """ + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: int | None = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.config = SimpleNamespace( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + timestep_guidance_channels=timestep_guidance_channels, + mlp_ratio=mlp_ratio, + axes_dims_rope=axes_dims_rope, + rope_theta=rope_theta, + eps=eps, + guidance_embeds=guidance_embeds, + ) + + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, + self.inner_dim, + elementwise_affine=False, + eps=eps, + bias=False, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + img_ids: torch.Tensor, + txt_ids: torch.Tensor, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + joint_attention_kwargs = joint_attention_kwargs or {} + + num_txt_tokens = encoder_hidden_states.shape[1] + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = 
self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for block in self.single_transformer_blocks: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = hidden_states[:, num_txt_tokens:, ...] + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "to_qkvkv_mlp_proj" in name: + name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") + if "to_qkv_mlp_proj" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ef706c3f598d7f3e171d3883d337d1ed0e7591 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -0,0 +1,996 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os +from collections.abc import Callable, Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2 +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +logger = init_logger(__name__) + + +class Flux2ImageProcessor(VaeImageProcessor): + """Image processor to preprocess the reference image for Flux2 klein.""" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, + max_aspect_ratio: int = 8, + min_side_length: int = 64, + max_area: int = 1024 * 1024, + ) -> PIL.Image.Image: + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + if width < min_side_length or height < min_side_length: + raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px") + + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). 
" + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + if width * height > max_area: + logger.warning("Image area exceeds recommended maximum; resizing will be applied.") + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + if image_width * image_height <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) + + def _resize_and_crop(self, image: PIL.Image.Image, width: int, height: int) -> PIL.Image.Image: + image_width, image_height = image.size + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + return image.crop((left, top, right, bottom)) + + @staticmethod + def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image: + if len(images) == 1: + return images[0].copy() + + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img + + +def get_flux2_klein_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput): + """Flux2 klein pipeline for text-to-image generation.""" + + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + is_distilled: bool = False, + ): + super().__init__() + self.od_config = od_config + self.is_distilled = is_distilled + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + 
self._execution_device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, + subfolder="scheduler", + local_files_only=local_files_only, + ) + self.text_encoder = Qwen3ForCausalLM.from_pretrained( + model, + subfolder="text_encoder", + local_files_only=local_files_only, + ) + self.tokenizer = Qwen2TokenizerFast.from_pretrained( + model, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + self.vae = AutoencoderKLFlux2.from_pretrained( + model, + subfolder="vae", + local_files_only=local_files_only, + ).to(self._execution_device) + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) + self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self._guidance_scale = None + self._attention_kwargs = None + self._num_timesteps = None + self._current_timestep = None + self._interrupt = False + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_positions = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_positions) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from 
diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + layer_ids = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, layer_ids) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. 
Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] 
= (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
+
+        latent_ids = self._prepare_latent_ids(latents)
+        latent_ids = latent_ids.to(device)
+
+        latents = self._pack_latents(latents)  # [B, C, H, W] -> [B, H*W, C]
+        return latents, latent_ids
+
+    # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
+    def prepare_image_latents(
+        self,
+        images: list[torch.Tensor],
+        batch_size,
+        generator: torch.Generator,
+        device,
+        dtype,
+    ):
+        image_latents = []
+        for image in images:
+            image = image.to(device=device, dtype=dtype)
+            image_latent = self._encode_vae_image(image=image, generator=generator)
+            image_latents.append(image_latent)  # (1, 128, 32, 32)
+
+        image_latent_ids = self._prepare_image_ids(image_latents)
+
+        # Pack each latent and concatenate
+        packed_latents = []
+        for latent in image_latents:
+            # latent: (1, 128, 32, 32)
+            packed = self._pack_latents(latent)  # (1, 1024, 128)
+            packed = packed.squeeze(0)  # (1024, 128) - remove batch dim
+            packed_latents.append(packed)
+
+        # Concatenate all reference tokens along sequence dimension
+        image_latents = torch.cat(packed_latents, dim=0)  # (N*1024, 128)
+        image_latents = image_latents.unsqueeze(0)  # (1, N*1024, 128)
+
+        image_latents = image_latents.repeat(batch_size, 1, 1)
+        image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
+        image_latent_ids = image_latent_ids.to(device)
+
+        return image_latents, image_latent_ids
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        guidance_scale=None,
+    ):
+        if (height is not None and height % (self.vae_scale_factor * 2) != 0) or (
+            width is not None and width % (self.vae_scale_factor * 2) != 0
+        ):
+            logger.warning(
+                "`height` and `width` have to be divisible by %s but are %s and %s. "
+                "Dimensions will be resized accordingly",
+                self.vae_scale_factor * 2,
+                height,
+                width,
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in ["latents", "prompt_embeds"] for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError("`callback_on_step_end_tensor_inputs` must be a subset of ['latents', 'prompt_embeds'].")
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if guidance_scale > 1.0 and self.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1 and not self.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), + ) -> DiffusionOutput: + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. 
+ """ + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + pass # use image from param list + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } + if self.do_classifier_free_guidance: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + if latents.dtype != self.vae.dtype: + latents = latents.to(self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/glm_image/__init__.py b/vllm_omni/diffusion/models/glm_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8256d8de6b4a158cc5d95932a2d359db4c3f64 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GLM Image diffusion model components.""" + +from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + 
GlmImageKVCache, + GlmImageTransformer2DModel, +) +from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import ( + GlmImagePipeline, + get_glm_image_post_process_func, + get_glm_image_pre_process_func, +) + +__all__ = [ + "GlmImageKVCache", + "GlmImagePipeline", + "GlmImageTransformer2DModel", + "get_glm_image_post_process_func", + "get_glm_image_pre_process_func", +] diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..69475181d282371382848779ebd88f925a144d40 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -0,0 +1,796 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from enum import Enum +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.cache.base import CachedTransformer +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +logger = init_logger(__name__) + + +class GlmImageImageProjector(nn.Module): + """Projects latent image patches to transformer hidden dimension.""" + + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + # Reshape: [B, C, H, W] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p] + hidden_states = hidden_states.reshape( + batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class GlmImageRotaryPosEmbed(nn.Module): + """Rotary positional embedding for 2D image patches.""" + + def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(height, device=hidden_states.device) + w_seq = torch.arange(width, device=hidden_states.device) + h_inv_freq = h_inv_freq.to(hidden_states.device) + w_inv_freq = 
w_inv_freq.to(hidden_states.device) + + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + # Create position matrices: [height, 1, dim//4] and [1, width, dim//4] + freqs_h = freqs_h.unsqueeze(1).expand(height, width, -1) + freqs_w = freqs_w.unsqueeze(0).expand(height, width, -1) + + # Concatenate: [height, width, dim//2] -> [height, width, dim] + freqs = torch.cat([freqs_h, freqs_w], dim=-1) + freqs = torch.cat([freqs, freqs], dim=-1) + freqs = freqs.reshape(height * width, -1) + return (freqs.cos(), freqs.sin()) + + +class GlmImageAdaLayerNormZero(nn.Module): + """Adaptive LayerNorm with zero initialization for both image and text streams.""" + + def __init__(self, embedding_dim: int, dim: int) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + dtype = hidden_states.dtype + norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) + + emb = self.linear(temb) + ( + shift_msa, + c_shift_msa, + scale_msa, + c_scale_msa, + gate_msa, + c_gate_msa, + shift_mlp, + c_shift_mlp, + scale_mlp, + c_scale_mlp, + gate_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + + hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) + + +class GlmImageAdaLayerNormContinuous(nn.Module): + """Final AdaLN for output projection (no activation before Linear).""" + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ): + super().__init__() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # NO SiLU here + emb = self.linear(conditioning_embedding.to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class KVCacheMode(Enum): + """Mode for KV cache operations. + + - WRITE: Store the K/V tensors from condition images + - READ: Concatenate cached K/V with current K/V + - SKIP: Do not use cache (pass-through) + """ + + WRITE = "write" + READ = "read" + SKIP = "skip" + + +class GlmImageLayerKVCache: + """KV cache for a single attention layer. + + Stores key and value tensors for image editing. The cache accumulates + KV pairs during write mode and provides them during read mode. + + Shape convention (vllm-omni): + key/value: [batch_size, seq_length, num_heads, head_dim] + """ + + def __init__(self): + self.k_cache: torch.Tensor | None = None + self.v_cache: torch.Tensor | None = None + + def store(self, key: torch.Tensor, value: torch.Tensor) -> None: + """Store or accumulate KV tensors. + + If cache is empty, stores the tensors directly. 
+ If cache is not empty, concatenates new tensors along seq_length dim. + + Args: + key: Key tensor of shape [B, S, H, D] + value: Value tensor of shape [B, S, H, D] + """ + if self.k_cache is None: + self.k_cache = key + self.v_cache = value + else: + # Concatenate along sequence dimension (dim=1 for [B, S, H, D]) + self.k_cache = torch.cat([self.k_cache, key], dim=1) + self.v_cache = torch.cat([self.v_cache, value], dim=1) + + def get(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Get cached KV tensors. + + Returns: + Tuple of (k_cache, v_cache), both may be None if cache is empty. + """ + return self.k_cache, self.v_cache + + def clear(self) -> None: + """Clear the cache.""" + self.k_cache = None + self.v_cache = None + + @property + def is_empty(self) -> bool: + """Check if cache is empty.""" + return self.k_cache is None + + def __repr__(self) -> str: + if self.is_empty: + return "GlmImageLayerKVCache(empty)" + return f"GlmImageLayerKVCache(k_shape={self.k_cache.shape}, v_shape={self.v_cache.shape})" + + +class GlmImageKVCache: + """Container for all layers' KV caches. + + Manages KV cache for all transformer layers in GLM-Image model. + Provides a unified interface for setting mode and clearing cache. + + Args: + num_layers: Number of transformer layers in the model. + + Example: + kv_cache = GlmImageKVCache(num_layers=28) + kv_cache.set_mode(KVCacheMode.WRITE) + # ... process condition image ... + kv_cache.set_mode(KVCacheMode.READ) + # ... process target image ... + kv_cache.clear() + """ + + def __init__(self, num_layers: int): + self.num_layers = num_layers + self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)] + self._mode: KVCacheMode | None = None + + def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache: + """Get cache for a specific layer. + + Args: + layer_idx: Index of the layer (0-indexed). + + Returns: + GlmImageLayerKVCache for the specified layer. + + Raises: + IndexError: If layer_idx is out of range. + """ + if layer_idx < 0 or layer_idx >= self.num_layers: + raise IndexError(f"Layer index {layer_idx} out of range [0, {self.num_layers})") + return self.caches[layer_idx] + + def __len__(self) -> int: + """Return number of layers.""" + return self.num_layers + + @property + def mode(self) -> KVCacheMode | None: + """Get current cache mode.""" + return self._mode + + def set_mode(self, mode: KVCacheMode | str | None) -> None: + """Set cache mode for all layers. + + Args: + mode: Cache mode (WRITE, READ, SKIP) or string ("write", "read", "skip"). + Use None to disable cache operations. + + Raises: + ValueError: If mode is an invalid string. + """ + if mode is None: + self._mode = None + elif isinstance(mode, str): + try: + self._mode = KVCacheMode(mode.lower()) + except ValueError: + raise ValueError(f"Invalid mode: '{mode}', must be one of 'write', 'read', 'skip'") + else: + self._mode = mode + + def clear(self) -> None: + """Clear cache for all layers and reset mode.""" + for cache in self.caches: + cache.clear() + self._mode = None + + @property + def is_empty(self) -> bool: + """Check if all layer caches are empty.""" + return all(cache.is_empty for cache in self.caches) + + def __repr__(self) -> str: + mode_str = self._mode.value if self._mode else "None" + return f"GlmImageKVCache(num_layers={self.num_layers}, mode={mode_str}, is_empty={self.is_empty})" + + +class GlmImageAttention(nn.Module): + """ + Joint attention for GLM-Image model using vllm-omni's optimized attention. 
+ + This combines text and image streams for joint attention computation. + Supports KV caching for image editing workflows via external cache. + """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + out_bias: bool = True, + eps: float = 1e-5, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + + # QKV projection (fused for efficiency) + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=head_dim, + total_num_heads=num_heads, + disable_tp=True, + bias=True, + ) + + # QK normalization (LayerNorm, not RMSNorm for GLM-Image) + self.norm_q = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps) + self.norm_k = nn.LayerNorm(head_dim, elementwise_affine=False, eps=eps) + + # Output projection + self.to_out = nn.Sequential( + nn.Linear(self.inner_dim, dim, bias=out_bias), + nn.Dropout(0.0), + ) + + # RoPE and attention + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + kv_cache: GlmImageLayerKVCache | None = None, + kv_cache_mode: KVCacheMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for joint attention. + + Args: + hidden_states: Image hidden states [B, img_seq_len, D] + encoder_hidden_states: Text hidden states [B, text_seq_len, D] + image_rotary_emb: Tuple of (cos, sin) for RoPE + attention_mask: Optional attention mask for text tokens + kv_cache: Optional layer KV cache for image editing + kv_cache_mode: Cache mode (WRITE, READ, SKIP) + + Returns: + Tuple of (image_hidden_states, text_hidden_states) + """ + dtype = encoder_hidden_states.dtype + batch_size, text_seq_length, _ = encoder_hidden_states.shape + + # Concatenate text and image: [text, image] + hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # QKV projection + qkv, _ = self.to_qkv(hidden_states_combined) + query, key, value = qkv.chunk(3, dim=-1) + + # Reshape: [B, S, H*D] -> [B, S, H, D] + query = query.unflatten(-1, (self.num_heads, -1)) + key = key.unflatten(-1, (self.num_heads, -1)) + value = value.unflatten(-1, (self.num_heads, -1)) + + # QK normalization + query = self.norm_q(query).to(dtype=dtype) + key = self.norm_k(key).to(dtype=dtype) + + # Apply RoPE only to image tokens (not text tokens) + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + # Only apply RoPE to image part (after text_seq_length) + query_img = query[:, text_seq_length:, :, :] + key_img = key[:, text_seq_length:, :, :] + from diffusers.models.embeddings import apply_rotary_emb + + query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + # key_img = self.rope(key_img, cos, sin) + key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) + key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) + + # Handle KV cache for image editing + if kv_cache is not None and kv_cache_mode is not None: + if kv_cache_mode == KVCacheMode.WRITE: + kv_cache.store(key, value) + elif kv_cache_mode == 
KVCacheMode.READ: + k_cached, v_cached = kv_cache.get() + if k_cached is not None: + key = torch.cat([k_cached, key], dim=1) + value = torch.cat([v_cached, value], dim=1) + # KVCacheMode.SKIP: do nothing + + # Attention computation + hidden_states_out = self.attn(query, key, value) + hidden_states_out = hidden_states_out.flatten(2, 3) + hidden_states_out = hidden_states_out.to(dtype) + + # Output projection + hidden_states_out = self.to_out(hidden_states_out) + + # Split back to text and image + encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] + hidden_states_out = hidden_states_out[:, text_seq_length:, :] + + return hidden_states_out, encoder_hidden_states_out + + +class GlmImageTransformerBlock(nn.Module): + """Single transformer block for GLM-Image.""" + + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ) -> None: + super().__init__() + + # 1. Attention with AdaLN + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) + self.attn1 = GlmImageAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + kv_cache: GlmImageLayerKVCache | None = None, + kv_cache_mode: KVCacheMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for transformer block. + + Args: + hidden_states: Image hidden states + encoder_hidden_states: Text hidden states + temb: Timestep embedding + image_rotary_emb: RoPE embeddings + attention_mask: Text attention mask + attention_kwargs: Additional attention arguments + kv_cache: Layer-specific KV cache for image editing + kv_cache_mode: Cache mode (WRITE, READ, SKIP) + + Returns: + Tuple of (image_hidden_states, text_hidden_states) + """ + # 1. Timestep conditioning via AdaLN + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, temb) + + # 2. Attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + kv_cache_mode=kv_cache_mode, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. 
Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class GlmImageTransformer2DModel(CachedTransformer): + """ + GLM-Image Transformer model for 2D image generation. + + This is the vllm-omni optimized version of the GLM-Image DiT model. + + Args: + od_config: OmniDiffusionConfig containing model configuration. + Transformer hyper-parameters (e.g. patch size / channels / heads) are read from + `od_config.tf_model_config`. + """ + + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + } + + def __init__( + self, + od_config: OmniDiffusionConfig, + ): + super().__init__() + + patch_size = od_config.tf_model_config.patch_size + in_channels = od_config.tf_model_config.in_channels + out_channels = od_config.tf_model_config.out_channels + num_attention_heads = od_config.tf_model_config.num_attention_heads + attention_head_dim = od_config.tf_model_config.attention_head_dim + time_embed_dim = od_config.tf_model_config.time_embed_dim + condition_dim = od_config.tf_model_config.condition_dim + prior_vq_quantizer_codebook_size = od_config.tf_model_config.prior_vq_quantizer_codebook_size + text_embed_dim = od_config.tf_model_config.text_embed_dim + + # Get num_layers from config if available + model_config = od_config.tf_model_config + if model_config is not None and hasattr(model_config, "num_layers"): + num_layers = model_config.num_layers + + self.od_config = od_config + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. RoPE + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. Patch & Text-timestep embedding + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. 
Output projection + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + kv_cache: GlmImageKVCache | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + Forward pass of the GLM-Image Transformer. + + Args: + hidden_states: Input latent tensor of shape [B, C, H, W]. + encoder_hidden_states: Text embeddings of shape [B, S, D]. + prior_token_id: Prior VQ token IDs. + prior_token_drop: Mask for dropping prior tokens (CFG). + timestep: Diffusion timestep. + target_size: Target image size for conditioning. + crop_coords: Crop coordinates for conditioning. + attention_kwargs: Additional attention arguments. + return_dict: Whether to return a dataclass. + attention_mask: Optional attention mask for text tokens. + image_rotary_emb: Pre-computed rotary embeddings. + kv_cache: Optional KV cache for image editing. When provided, + the cache's mode determines behavior: + - WRITE: Store KV from condition images + - READ: Use cached KV during generation + - SKIP: No caching (same as None) + + Returns: + Output tensor or Transformer2DModelOutput. + """ + batch_size, num_channels, height, width = hidden_states.shape + + # Get KV cache mode + kv_cache_mode = kv_cache.mode if kv_cache is not None else None + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + # Move to correct device + image_rotary_emb = ( + image_rotary_emb[0].to(hidden_states.device), + image_rotary_emb[1].to(hidden_states.device), + ) + + # 2. Patch & Timestep embeddings + p = self.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) + + # Prior embedding with dropout + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + hidden_states = hidden_states + prior_hidden_states + + # Timestep conditioning + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + + # 3. Transformer blocks + for layer_idx, block in enumerate(self.transformer_blocks): + # Get layer-specific KV cache if available + layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None + + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_cache=layer_kv_cache, + kv_cache_mode=kv_cache_mode, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from pretrained checkpoint. + + This method handles the mapping from diffusers weight names to vllm-omni weight names, + especially for fused QKV projections. + """ + stacked_params_mapping = [ + # Fused QKV projection: to_q, to_k, to_v -> to_qkv + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # Also include buffers (for any beta/eps parameters) + for name, buffer in self.named_buffers(): + params_dict[name] = buffer + + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Handle fused QKV projections + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + # Map diffusers name to vllm-omni name + name = name.replace(weight_name, param_name) + + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + break + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Standard weight loading (not fused) + if name not in params_dict: + logger.warning(f"Skipping weight {name} - not found in model") + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + def create_kv_cache(self) -> GlmImageKVCache: + """ + Create a KV cache for image editing. + + Returns a new GlmImageKVCache instance sized for this model's + number of transformer layers. Use this for image editing workflows. + + Example: + kv_cache = transformer.create_kv_cache() + kv_cache.set_mode("write") + transformer(condition_image, kv_cache=kv_cache) + kv_cache.set_mode("read") + for t in timesteps: + transformer(noisy_target, kv_cache=kv_cache) + kv_cache.clear() + + Returns: + GlmImageKVCache instance with correct number of layers. + """ + return GlmImageKVCache(num_layers=len(self.transformer_blocks)) + + @property + def num_layers(self) -> int: + """Return number of transformer layers.""" + return len(self.transformer_blocks) + + @property + def dtype(self) -> torch.dtype: + """Return dtype of model parameters.""" + return next(self.parameters()).dtype diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c167c32a97a85b317a457e02b88ecf88ba49f7 --- /dev/null +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -0,0 +1,1015 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GlmImagePipeline implementation for vLLM-Omni. 
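+
+A hypothetical direct invocation (normally the DiffusionEngine wires these
+pieces together; request construction and config loading are assumed):
+
+    pipeline = GlmImagePipeline(od_config=od_config)
+    result = pipeline(request)   # request: OmniDiffusionRequest
+    images = get_glm_image_post_process_func(od_config)(result.output)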
+ +This pipeline implements GLM-Image text-to-image generation with: +- AR stage: GlmImageForConditionalGeneration generates prior tokens +- DiT stage: GlmImageTransformer2DModel performs diffusion denoising +- VAE: AutoencoderKL decodes latents to images +""" + +from __future__ import annotations + +import inspect +import json +import logging +import os +import re +from collections.abc import Iterable +from typing import cast + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import ( + ByT5Tokenizer, + GlmImageForConditionalGeneration, + GlmImageProcessor, + T5EncoderModel, +) + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.glm_image.glm_image_transformer import ( + GlmImageKVCache, + GlmImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_glm_image_pre_process_func(od_config: OmniDiffusionConfig): + """Get pre-processing function for GLM-Image pipeline. + + Pre-processes condition images before they are sent to the pipeline. + This is called by DiffusionEngine before batching requests. 
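+
+ For example, with the default VAE config (block_out_channels of length 4,
+ hence vae_scale_factor = 8) and patch_size = 2, condition images are
+ snapped down to multiples of 16 (numbers below are illustrative):
+
+     multiple_of = 8 * 2
+     img_h = (1035 // multiple_of) * multiple_of   # -> 1024
+     img_w = (777 // multiple_of) * multiple_of    # -> 768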
+ """ + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + block_out_channels = vae_config.get("block_out_channels", [128, 256, 512, 512]) + vae_scale_factor = 2 ** (len(block_out_channels) - 1) + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + # GLM-Image uses patch_size=2 for transformer + patch_size = 2 + + def pre_process_func(request: OmniDiffusionRequest): + """Pre-process condition images for Image Edit mode.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + # Text-to-image mode, no preprocessing needed + continue + + if not isinstance(raw_image, list): + raw_image = [raw_image] + images = [ + PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) + for im in raw_image + ] + + preprocessed = [] + height, width = None, None + + for img in images: + if isinstance(img, PIL.Image.Image): + img_h, img_w = img.size[::-1] # PIL is (width, height) + else: + img_h, img_w = img.shape[:2] + + # Align to multiple of vae_scale_factor * patch_size + multiple_of = vae_scale_factor * patch_size + img_h = (img_h // multiple_of) * multiple_of + img_w = (img_w // multiple_of) * multiple_of + + processed = image_processor.preprocess(img, height=img_h, width=img_w) + preprocessed.append(processed) + + # Use first image dimensions as default + if height is None: + height, width = img_h, img_w + + # Store in request + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt, additional_information={}) + elif "additional_information" not in prompt: + prompt["additional_information"] = {} + prompt["additional_information"]["preprocessed_image"] = processed # type: ignore + prompt["additional_information"]["prompt_image"] = images # type: ignore + request.prompts[i] = prompt + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + return request + + return pre_process_func + + +def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): + """Get post-processing function for GLM-Image pipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + block_out_channels = vae_config.get("block_out_channels", [128, 256, 512, 512]) + vae_scale_factor = 2 ** (len(block_out_channels) - 1) + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]: + return image_processor.postprocess(images, output_type="pil") + + return post_process_func + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 
0.75, +) -> float: + """Calculate timestep shift based on image sequence length.""" + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps. + Handles custom timesteps and sigmas schedules. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + # Both provided - check if scheduler supports both + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +) -> torch.Tensor: + """Extract latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(nn.Module): + """ + GLM-Image Pipeline for text-to-image and image-to-image generation. + + This pipeline integrates: + - AR model (GlmImageForConditionalGeneration): Generates prior image tokens + - Text encoder (T5EncoderModel): Encodes glyph/text embeddings + - DiT model (GlmImageTransformer2DModel): Diffusion transformer + - VAE (AutoencoderKL): Encodes/decodes images to/from latent space + + The pipeline flow: + 1. AR generates prior_token_ids from text prompt + 2. T5 encodes glyph text for text rendering + 3. DiT performs iterative denoising conditioned on prior tokens + 4. 
VAE decodes final latents to image + """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + self.device = get_local_device() + + model = od_config.model + local_files_only = os.path.exists(model) + + if local_files_only: + model_path = model + else: + model_path = download_weights_from_hf_specific(model, od_config.revision, ["*"]) + + # Load scheduler + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_path, subfolder="scheduler", local_files_only=True + ) + + # Load AR model (vision_language_encoder) + logger.info("Loading GlmImageForConditionalGeneration (AR model)...") + self.vision_language_encoder = GlmImageForConditionalGeneration.from_pretrained( + model_path, + subfolder="vision_language_encoder", + local_files_only=True, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.vision_language_encoder.eval() + + # Load processor for AR model + self.processor = GlmImageProcessor.from_pretrained(model_path, subfolder="processor", local_files_only=True) + + # Load text encoder (T5 for glyph embeddings) + logger.info("Loading T5EncoderModel (glyph encoder)...") + self.text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + local_files_only=True, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.text_encoder.eval() + + # Load tokenizer for glyph encoding + self.tokenizer = ByT5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", local_files_only=True) + + # Load VAE + logger.info("Loading AutoencoderKL (VAE)...") + self.vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16 + ).to(self.device) + self.vae.eval() + + # Load transformer (DiT) + logger.info("Loading GlmImageTransformer2DModel (DiT)...") + self.transformer = GlmImageTransformer2DModel(od_config=od_config) + + # Weight sources for DiT loading + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=od_config.revision, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + # Configure scale factors + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 128 + + # Get transformer config for patch size + self._patch_size = getattr(self.transformer, "patch_size", 2) + + # ==================== Input Validation ==================== + + def check_inputs( + self, + prompt: str | list[str] | None, + height: int | None, + width: int | None, + prompt_embeds: torch.Tensor | None = None, + ) -> None: + """Validate input arguments before generation.""" + # Check dimension alignment + multiple_of = self.vae_scale_factor * self._patch_size + if height is not None and height % multiple_of != 0: + logger.warning( + f"`height` should be divisible by {multiple_of} but is {height}. " + "Dimensions will be adjusted accordingly." + ) + if width is not None and width % multiple_of != 0: + logger.warning( + f"`width` should be divisible by {multiple_of} but is {width}. Dimensions will be adjusted accordingly." + ) + + # Check prompt/prompt_embeds mutual exclusivity + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
" + "Please provide only one of the two." + ) + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.") + + # Check prompt type + if prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` must be of type `str` or `list` but is {type(prompt)}") + + # ==================== AR Stage Methods ==================== + + @staticmethod + def _compute_generation_params( + image_grid_thw: torch.Tensor, + is_text_to_image: bool, + ) -> tuple[int, int, int, int]: + """ + Compute AR generation parameters from image grid. + + Args: + image_grid_thw: Image grid tensor of shape [N, 3] where each row is [t, h, w] + is_text_to_image: Whether this is text-to-image (vs image-to-image) + + Returns: + Tuple of (max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w) + """ + grid_sizes = [] + grid_hw = [] + + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + grid_hw.append((int(h), int(w))) + + if not is_text_to_image: + # Image-to-image: only generate target image tokens + max_new_tokens = grid_sizes[-1] + 1 + large_image_start_offset = 0 + target_grid_h, target_grid_w = grid_hw[-1] + else: + # Text-to-image: generate both small preview and large target + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + large_image_start_offset = sum(grid_sizes[1:]) + target_grid_h, target_grid_w = grid_hw[0] + + return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + """Extract large image tokens from AR output.""" + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + """Upsample token IDs by 2x using nearest neighbor interpolation.""" + token_ids = token_ids.view(1, 1, token_h, token_w) + token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( + dtype=torch.long + ) + token_ids = token_ids.view(1, -1) + return token_ids + + @torch.inference_mode() + def generate_prior_tokens( + self, + prompt: str, + height: int, + width: int, + image: list[PIL.Image.Image] | None = None, + factor: int = 32, + ) -> tuple[torch.Tensor, list[torch.Tensor] | None]: + """ + Generate prior tokens using the AR model. 
+ + Args: + prompt: Text prompt for generation + height: Target image height + width: Target image width + image: Optional condition images for image-to-image + factor: Token factor (default 32) + + Returns: + Tuple of (prior_token_ids, prior_token_image_ids) + prior_token_image_ids is a list of tensors, one per condition image + """ + device = self.vision_language_encoder.device + height = (height // factor) * factor + width = (width // factor) * factor + is_text_to_image = image is None or len(image) == 0 + + # Build message content + content = [] + if image is not None: + for img in image: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": prompt}) + messages = [{"role": "user", "content": content}] + + # Apply chat template - processor will handle target dimensions and build grid + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + target_h=height, + target_w=width, + return_dict=True, + return_tensors="pt", + ).to(device) + + image_grid_thw = inputs.get("image_grid_thw") + + # Compute generation parameters from the full grid + max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( + image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image + ) + + # Process condition images if provided + # Use image_grid_thw[:-1] to exclude the target image grid (last entry) + prior_token_image_ids = None + if image is not None and image_grid_thw is not None and len(image_grid_thw) > 1: + # Get features only for condition images (exclude target image grid) + condition_grid = image_grid_thw[:-1] + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], condition_grid + ).pooler_output + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + flat_prior_token_image_ids = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, condition_grid + ) + # Split by image grid sizes and convert to list + split_sizes = (condition_grid.prod(dim=-1)).tolist() + prior_token_image_ids_list = torch.split(flat_prior_token_image_ids, split_sizes, dim=0) + # Convert to list with upsampling + prior_token_image_ids = [] + for i, token_ids in enumerate(prior_token_image_ids_list): + grid_t, grid_h, grid_w = condition_grid[i].tolist() + token_ids = token_ids.view(1, -1) + # Upsample 2x (from d32 to d64) + token_ids_upsampled = self._upsample_token_ids(token_ids, grid_h, grid_w) + prior_token_image_ids.append(token_ids_upsampled) + + # Generate with AR model + outputs = self.vision_language_encoder.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + + # Extract and upsample tokens + large_image_tokens = token_h * token_w + prior_token_ids_d32 = self._extract_large_image_tokens( + outputs, inputs["input_ids"].shape[-1], large_image_offset, large_image_tokens + ) + prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) + + return prior_token_ids, prior_token_image_ids + + # ==================== Text Encoding Methods ==================== + + def get_glyph_texts(self, prompt: str | list[str]) -> list[str]: + """Extract text within quotes for glyph rendering.""" + prompt = prompt[0] if isinstance(prompt, list) else prompt + ocr_texts = ( + re.findall(r"'([^']*)'", prompt) + + re.findall(r"“([^“”]*)”", prompt) + + re.findall(r'"([^"]*)"', prompt) + + re.findall(r"「([^「」]*)」", prompt) + ) + return ocr_texts + + def _get_glyph_embeds( + self, + prompt: str | list[str], + max_sequence_length: int = 
2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + """Get glyph embeddings from T5 encoder for text rendering.""" + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + glyph_texts = self.get_glyph_texts(prompt) + input_ids = self.tokenizer( + glyph_texts if len(glyph_texts) > 0 else [""], + max_length=max_sequence_length, + truncation=True, + ).input_ids + + # Pad to even length + input_ids = [[self.tokenizer.pad_token_id] * ((len(ids) + 1) % 2) + ids for ids in input_ids] + max_length = max(len(ids) for ids in input_ids) + + attention_mask = torch.tensor( + [[1] * len(ids) + [0] * (max_length - len(ids)) for ids in input_ids], + device=device, + ) + input_ids = torch.tensor( + [ids + [self.tokenizer.pad_token_id] * (max_length - len(ids)) for ids in input_ids], + device=device, + ) + + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) + glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) + + return glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 2048, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Encode prompt into glyph embeddings for text rendering.""" + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = [""] * batch_size + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # ==================== Latent Preparation ==================== + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare random noise latents.""" + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError(f"Passed {len(generator)} generators but batch size is {batch_size}.") + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def diffuse( + self, + latents: torch.Tensor, + prior_token_id: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + timesteps: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + guidance_scale: float, + 
do_classifier_free_guidance: bool, + kv_caches: GlmImageKVCache | None = None, + ) -> torch.Tensor: + """ + Denoising loop for diffusion process with CFG-Parallel support. + + Args: + latents: Initial noise latents + prior_token_id: Prior tokens generated by AR model + prompt_embeds: Encoded positive prompt embeddings (glyph embeddings) + negative_prompt_embeds: Encoded negative prompt embeddings + timesteps: Denoising timesteps + target_size: Target image size tensor [[height, width]] + crop_coords: Crop coordinates tensor + guidance_scale: CFG scale + do_classifier_free_guidance: Whether to apply CFG + kv_caches: Optional KV cache for Image Edit mode + + Returns: + Denoised latents ready for VAE decode + """ + # Prepare conditional/unconditional drop flags + prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool) + + transformer_dtype = self.transformer.dtype + + # Enable CFG-parallel: rank0 computes positive, rank1 computes negative + cfg_parallel_ready = do_classifier_free_guidance and get_classifier_free_guidance_world_size() > 1 + + for i, t in enumerate(timesteps): + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) - 1 + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_rank == 0: + # Rank 0: Compute positive (conditional) prediction + local_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + target_size=target_size, + crop_coords=crop_coords, + kv_cache=kv_caches, + return_dict=False, + )[0].float() + else: + # Rank 1: Compute negative (unconditional) prediction + local_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + target_size=target_size, + crop_coords=crop_coords, + kv_cache=kv_caches, + return_dict=False, + )[0].float() + + # All-gather predictions from all ranks + gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + + if cfg_rank == 0: + # Rank 0: Combine predictions and apply CFG + noise_pred_cond = gathered[0] + noise_pred_uncond = gathered[1] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Broadcast updated latents to all ranks + cfg_group.broadcast(latents, src=0) + + else: + # Sequential CFG (single GPU or no CFG) + # Conditional forward pass + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + target_size=target_size, + crop_coords=crop_coords, + kv_cache=kv_caches, + return_dict=False, + )[0].float() + + if do_classifier_free_guidance: + # Unconditional forward pass + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + target_size=target_size, + crop_coords=crop_coords, + kv_cache=kv_caches, + return_dict=False, + )[0].float() + + noise_pred = noise_pred_uncond + 
guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + return latents + + # ==================== Main Forward Pass ==================== + + def _prepare_condition_image_kv_cache( + self, + condition_images: list[torch.Tensor], + prior_token_image_ids: list[torch.Tensor], + prompt_embeds: torch.Tensor, + generator: torch.Generator | None = None, + ) -> GlmImageKVCache: + """ + Prepare KV cache by running condition images through transformer at timestep 0. + + This is used for Image Edit mode where we need to cache the condition image's + KV states for cross-attention during denoising. + + Args: + condition_images: List of preprocessed condition images + prior_token_image_ids: Prior token IDs for each condition image from AR model + prompt_embeds: Prompt embeddings (used to get dtype) + generator: Optional random generator + + Returns: + GlmImageKVCache with cached KV states from condition images + """ + kv_caches = self.transformer.create_kv_cache() + kv_caches.set_mode("write") + + # Prepare VAE normalization parameters + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(device=self.device, dtype=prompt_embeds.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(device=self.device, dtype=prompt_embeds.dtype) + ) + + # Process each condition image through transformer to populate KV cache + for condition_image, condition_prior_token_id in zip(condition_images, prior_token_image_ids): + condition_image = condition_image.to(device=self.device, dtype=prompt_embeds.dtype) + + # Encode condition image to latent space + # Use argmax (mode) for deterministic encoding of condition images + condition_latent = retrieve_latents( + self.vae.encode(condition_image), generator=generator, sample_mode="argmax" + ) + condition_latent = (condition_latent - latents_mean) / latents_std + + # Run forward pass at timestep 0 to cache KV states + # Empty encoder_hidden_states since we only want to cache image features + _ = self.transformer( + hidden_states=condition_latent, + encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...], + prior_token_id=condition_prior_token_id, + prior_token_drop=torch.full_like(condition_prior_token_id, False, dtype=torch.bool), + timestep=torch.zeros((1,), device=self.device), + target_size=torch.tensor([condition_image.shape[-2:]], device=self.device, dtype=prompt_embeds.dtype), + crop_coords=torch.zeros((1, 2), device=self.device, dtype=prompt_embeds.dtype), + kv_cache=kv_caches, + return_dict=False, + ) + + return kv_caches + + @torch.inference_mode() + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + """ + Main generation forward pass. 
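+
+ In multistage deployments the AR stage runs separately and hands its
+ output over through extra_args; a minimal sketch (values illustrative):
+
+     req.sampling_params.extra_args["prior_token_ids"] = prior_ids        # 1-D or [1, N]
+     req.sampling_params.extra_args["prior_token_image_ids"] = None       # text-to-image
+     output = pipeline(req)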
+ + Args: + req: OmniDiffusionRequest with generation parameters + + Returns: + DiffusionOutput containing generated image + """ + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + + # Get pre-computed prompt embeddings if provided + if isinstance(first_prompt, str): + prompt_embeds = None + else: + prompt_embeds = first_prompt.get("prompt_embeds") + if not isinstance(prompt_embeds, torch.Tensor): + prompt_embeds = None + + # Get condition images for Image Edit mode + # Use pre-processed images from pre_process_func + preprocessed_images = ( + None + if isinstance(first_prompt, str) + else [first_prompt.get("additional_information", {}).get("preprocessed_image")] + ) + condition_images = ( + None + if isinstance(first_prompt, str) + else first_prompt.get("additional_information", {}).get("prompt_image") + ) + img_height = req.sampling_params.height + img_width = req.sampling_params.width + + is_image_edit = preprocessed_images is not None + + # Use image dimensions as default if available + height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or 50 + guidance_scale = req.sampling_params.guidance_scale or 1.5 + + # 0. Validate inputs + self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) + + batch_size = 1 + do_classifier_free_guidance = guidance_scale > 1.0 + + # Set seed if provided + generator = None + if req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + + # 1. Get prior tokens - either from external source (multistage) or generate internally + # Check if prior_token_ids are provided externally (from AR stage in multistage mode) + external_prior_tokens = req.sampling_params.extra_args.get("prior_token_ids") + external_prior_image_ids = req.sampling_params.extra_args.get("prior_token_image_ids") + + if external_prior_tokens is not None: + # Multistage mode: use externally provided prior tokens from vLLM AR stage + logger.info("Using externally provided prior tokens from AR stage...") + prior_token_id = external_prior_tokens + if isinstance(prior_token_id, list): + prior_token_id = torch.tensor(prior_token_id, dtype=torch.long, device=self.device) + elif isinstance(prior_token_id, torch.Tensor): + prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long) + # Ensure shape is [1, num_tokens] for batch processing + if prior_token_id.dim() == 1: + prior_token_id = prior_token_id.unsqueeze(0) + prior_token_image_ids = external_prior_image_ids + else: + # Single-stage mode: generate prior tokens with internal AR model + logger.info("Generating prior tokens with AR model...") + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt, + image=condition_images, + height=height, + width=width, + ) + + # 2. 
Encode prompt for glyph embeddings + logger.info("Encoding prompt...") + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=1, + prompt_embeds=prompt_embeds, + device=self.device, + dtype=self.transformer.dtype, + ) + + # 3. Prepare KV cache for Image Edit mode + kv_caches = None + if is_image_edit and prior_token_image_ids is not None: + logger.info("Preparing KV cache for Image Edit mode...") + kv_caches = self._prepare_condition_image_kv_cache( + condition_images=preprocessed_images, + prior_token_image_ids=prior_token_image_ids, + prompt_embeds=prompt_embeds, + generator=generator, + ) + # Switch to read mode for denoising + kv_caches.set_mode("read") + + # 4. Prepare latents + latent_channels = self.transformer.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=self.device, + generator=generator, + ) + + # 5. Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (self._patch_size**2) + timesteps_array = np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + timesteps_array = timesteps_array.astype(np.int64).astype(np.float32) + sigmas = timesteps_array / self.scheduler.config.num_train_timesteps + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, self.device, timesteps_array.tolist(), sigmas.tolist(), mu=mu + ) + + # 6. Prepare conditioning tensors + target_size = torch.tensor([[height, width]], dtype=prompt_embeds.dtype, device=self.device) + crop_coords = torch.zeros((1, 2), dtype=prompt_embeds.dtype, device=self.device) + + # 7. Denoising loop with CFG-parallel support + logger.info(f"Starting denoising loop with {num_inference_steps} steps...") + latents = self.diffuse( + latents=latents, + prior_token_id=prior_token_id, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + timesteps=timesteps, + target_size=target_size, + crop_coords=crop_coords, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + kv_caches=kv_caches, + ) + + # 8. VAE decode + logger.info("Decoding latents with VAE...") + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + + # 9. 
Leave post-process to vllm-omni pipeline + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load transformer weights.""" + # Filter weights for transformer only + transformer_weights = ( + (name.replace("transformer.", ""), weight) for name, weight in weights if name.startswith("transformer.") + ) + return self.transformer.load_weights(transformer_weights) diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..abdfe1e50a44503db4ec24e1a5c0d5c2469cd5e4 --- /dev/null +++ b/vllm_omni/diffusion/models/interface.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ( + ClassVar, + Protocol, + runtime_checkable, +) + + +@runtime_checkable +class SupportImageInput(Protocol): + support_image_input: ClassVar[bool] = True + color_format: ClassVar[str] = "RGB" # Default color format + + +@runtime_checkable +class SupportAudioOutput(Protocol): + support_audio_output: ClassVar[bool] = True diff --git a/vllm_omni/diffusion/models/longcat_image/__init__.py b/vllm_omni/diffusion/models/longcat_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..105f51a26124874e8b54422fee836bb98a4bde20 --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel +from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import ( + LongCatImagePipeline, + get_longcat_image_post_process_func, +) + +__all__ = [ + "LongCatImagePipeline", + "LongCatImageTransformer2DModel", + "get_longcat_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6d54aef7182ac2129879e7945ef2b69f63621c --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -0,0 +1,757 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state 
import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) +from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, bias: bool = True): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.w_in = ColumnParallelLinear(dim, inner_dim, bias=bias, return_bias=False) + self.act = get_act_fn("gelu_pytorch_tanh") + self.w_out = RowParallelLinear(inner_dim, dim_out, bias=bias, return_bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.w_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.w_out(hidden_states) + return hidden_states + + +class LongCatImageAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + # Fused QKV projection using vLLM's optimized layer + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=bias, + ) + + if not self.pre_only: + self.to_out = RowParallelLinear(self.inner_dim, self.out_dim, bias=out_bias) + + if self.added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + + self.add_kv_proj = QKVParallelLinear( + hidden_size=self.added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=added_proj_bias, + ) + + self.to_add_out = RowParallelLinear(self.inner_dim, query_dim, bias=out_bias) + + self.attn = Attention( + num_heads=heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def _sp_attention_with_rope( + self, + img_query: torch.Tensor, + img_key: torch.Tensor, + img_value: torch.Tensor, + text_query: torch.Tensor, + text_key: torch.Tensor, + text_value: torch.Tensor, + text_seq_len: int, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None, + ) -> torch.Tensor: + """ + Apply RoPE separately to text and image Q/K, then run SP attention with joint tensors. + + This is the common SP attention pattern used by both dual-stream (added_kv_proj_dim) + and single-stream (no added_kv_proj_dim) blocks. 
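+
+ Text Q/K/V are handed to the attention backend as joint tensors with
+ joint_strategy="front", i.e. the (replicated) text tokens are placed in
+ front of the locally chunked image tokens, matching the [text, image]
+ concatenation used in the non-SP path.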
+ + Args: + img_query/key/value: Image Q/K/V tensors (chunked in SP mode) + text_query/key/value: Text Q/K/V tensors (full, not chunked) + text_seq_len: Length of text sequence for splitting RoPE + image_rotary_emb: (freqs_cos, freqs_sin) containing [txt_pos, img_pos] + + Returns: + Attention output with shape (B, txt_len + img_len/SP, H, D) + """ + if image_rotary_emb is not None: + freqs_cos, freqs_sin = image_rotary_emb + txt_rotary_emb = (freqs_cos[:text_seq_len], freqs_sin[:text_seq_len]) + img_rotary_emb_split = (freqs_cos[text_seq_len:], freqs_sin[text_seq_len:]) + # Apply RoPE to image Q/K + img_query = apply_rotary_emb(img_query, img_rotary_emb_split, sequence_dim=1) + img_key = apply_rotary_emb(img_key, img_rotary_emb_split, sequence_dim=1) + # Apply RoPE to text Q/K + text_query = apply_rotary_emb(text_query, txt_rotary_emb, sequence_dim=1) + text_key = apply_rotary_emb(text_key, txt_rotary_emb, sequence_dim=1) + + return self.attn( + img_query, + img_key, + img_value, + AttentionMetadata( + joint_query=text_query, + joint_key=text_key, + joint_value=text_value, + joint_strategy="front", + ), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with SP-aware joint attention. + + Input shapes (in SP mode): + - hidden_states: (B, img_seq_len // SP, D) - image hidden states (chunked) + - encoder_hidden_states: (B, txt_seq_len, D) - text hidden states (full) + + SP Mode (sequence_parallel_size > 1): + - Image Q/K/V: processed with AllToAll or Ring communication + - Text Q/K/V: passed as joint tensors, broadcasted to all ranks + - Output: attention over (text + image) with proper SP handling + + Non-SP Mode (sequence_parallel_size = 1): + - Standard concatenation of text + image Q/K/V + - Regular attention over the full sequence + """ + qkv, _ = self.to_qkv(hidden_states) + + q_size = self.to_qkv.num_heads * self.head_dim + kv_size = self.to_qkv.num_kv_heads * self.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + + query = query.unflatten(-1, (self.to_qkv.num_heads, -1)) + key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + q_size = self.add_kv_proj.num_heads * self.head_dim + kv_size = self.add_kv_proj.num_kv_heads * self.head_dim + encoder_query, encoder_key, encoder_value = encoder_qkv.split([q_size, kv_size, kv_size], dim=-1) + + encoder_query = encoder_query.unflatten(-1, (self.add_kv_proj.num_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.add_kv_proj.num_kv_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.add_kv_proj.num_kv_heads, -1)) + + # Apply RMSNorm to text Q/K + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + # Check if SP is enabled from forward context (set by LongCatImageTransformer2DModel) + forward_ctx = get_forward_context() + sp_size = forward_ctx.sequence_parallel_size + use_sp_joint_attention = sp_size > 1 and not forward_ctx.split_text_embed_in_sp + + if use_sp_joint_attention: + # SP Mode: Use common helper for RoPE + joint attention + hidden_states = self._sp_attention_with_rope( + img_query=query, + img_key=key, + img_value=value, + text_query=encoder_query, + 
text_key=encoder_key, + text_value=encoder_value, + text_seq_len=encoder_query.shape[1], + image_rotary_emb=image_rotary_emb, + ) + else: + # Non-SP Mode: Concat first, then apply RoPE to full sequence + joint_query = torch.cat([encoder_query, query], dim=1) + joint_key = torch.cat([encoder_key, key], dim=1) + joint_value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + # Apply RoPE to full (text + image) sequence + joint_query = apply_rotary_emb(joint_query, image_rotary_emb, sequence_dim=1) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb, sequence_dim=1) + + hidden_states = self.attn( + joint_query, + joint_key, + joint_value, + ) + else: + # No added_kv_proj_dim: single stream attention (e.g., from SingleTransformerBlock) + # hidden_states is the combined (text + image) sequence + # In SP mode, image part is chunked: (B, txt_len + img_len/SP, D) + + # Check if SP is enabled and we have text_seq_len info + forward_ctx = get_forward_context() + sp_size = forward_ctx.sequence_parallel_size + text_seq_len = kwargs.get("text_seq_len", None) + use_sp_single_stream = sp_size > 1 and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None + + if use_sp_single_stream: + # SP Mode for single-stream block: + # Split QKV into text and image parts, then use common helper + hidden_states = self._sp_attention_with_rope( + img_query=query[:, text_seq_len:], + img_key=key[:, text_seq_len:], + img_value=value[:, text_seq_len:], + text_query=query[:, :text_seq_len], + text_key=key[:, :text_seq_len], + text_value=value[:, :text_seq_len], + text_seq_len=text_seq_len, + image_rotary_emb=image_rotary_emb, + ) + else: + # Non-SP Mode: standard path + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split output back into text and image portions + # In SP mode: seq_len = txt_seq_len + img_seq_len // SP + # In non-SP mode: seq_len = txt_seq_len + img_seq_len + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states, _ = self.to_out(hidden_states) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + # For single-stream blocks, there's no to_out (RowParallelLinear) to handle the reduction + if get_tensor_model_parallel_world_size() > 1: + hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=-1) + return hidden_states + + +class LongCatImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = 
FeedForward(dim=dim, dim_out=dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LongCatImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class LongCatImageTimestepEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # 
(N, D) + + return timesteps_emb + + +class LongCatImageSingleTransformerBlock(nn.Module): + """ + Single-stream Transformer block for LongCat with SP (Sequence Parallelism) support. + + SP handling is delegated to LongCatImageAttention via the text_seq_len parameter. + This keeps the block logic clean and centralizes SP logic in the attention layer. + """ + + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + # SP handling is delegated to LongCatImageAttention via text_seq_len kwarg + self.attn = LongCatImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for SingleTransformerBlock with SP support. + + SP handling is delegated to LongCatImageAttention.forward via text_seq_len kwarg. + This keeps the block logic clean and centralizes SP logic in the attention layer. + """ + text_seq_len = encoder_hidden_states.shape[1] + + # Concatenate text and image + # In SP mode: image is chunked (B, img_len/SP, D), text is full (B, txt_len, D) + combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = combined + norm_hidden_states, gate = self.norm(combined, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + # Delegate SP handling to LongCatImageAttention by passing text_seq_len + # LongCatImageAttention will detect SP mode and handle text/image splitting internally + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + text_seq_len=text_seq_len, # Pass text_seq_len for SP mode handling + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class LongCatImageTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux. + + Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig. 
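+
+    When sequence parallelism is active, forward() chunks the packed image tokens and their
+    RoPE frequencies across SP ranks, keeps the text embeddings replicated on every rank,
+    and all-gathers the projected output along the sequence dimension before returning it;
+    see forward() below for the exact handling.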
+ """ + + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 3584, + axes_dims_rope: list[int] = [16, 56, 56], + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pooled_projection_dim = pooled_projection_dim + + # Store parallel config for SP support + self.parallel_config = od_config.parallel_config + + self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + LongCatImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + LongCatImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.use_checkpoint = [True] * num_layers + self.use_single_checkpoint = [True] * num_single_layers + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> torch.FloatTensor | Transformer2DModelOutput: + # Before: hidden_states shape = (B, img_seq_len, in_channels) + # After: hidden_states shape = (B, img_seq_len // SP, in_channels) + sp_size = self.parallel_config.sequence_parallel_size + # Store SP size in forward context for sub-modules to access + get_forward_context().sequence_parallel_size = sp_size + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + original_shape = hidden_states.shape + hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] + # LongCat uses dual-stream (text + image) with joint attention + # Text embeddings should be replicated across SP ranks for correctness + get_forward_context().split_text_embed_in_sp = False + # Debug log (only first forward) + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info( + f"[LongCat Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " + f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" + ) + else: + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}") + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = self.time_embed(timestep, 
hidden_states.dtype) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=0) + + if current_omni_platform.is_npu(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + # SP: Chunk RoPE embeddings along sequence dimension + if self.parallel_config.sequence_parallel_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_cos, freqs_sin = image_rotary_emb + txt_len = txt_ids.shape[0] + + # Split RoPE into text and image portions + # txt_freqs: (txt_seq_len, head_dim) - keep full for all ranks + # img_freqs: (img_seq_len, head_dim) -> (img_seq_len // SP, head_dim) + txt_freqs_cos = freqs_cos[:txt_len] + txt_freqs_sin = freqs_sin[:txt_len] + img_freqs_cos = freqs_cos[txt_len:] + img_freqs_sin = freqs_sin[txt_len:] + + # Chunk image RoPE for each SP rank + # img_freqs_cos: (img_seq_len // SP, head_dim) + # img_freqs_sin: (img_seq_len // SP, head_dim) + img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] + img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] + + # Optionally chunk text RoPE if split_text_embed_in_sp is True + if get_forward_context().split_text_embed_in_sp: + txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] + txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] + + # Reconstruct image_rotary_emb with chunked values + # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) + image_rotary_emb = ( + torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), + torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), + ) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + # SP: All-gather output to reconstruct full sequence + if self.parallel_config.sequence_parallel_size > 1: + output = get_sp_group().all_gather(output, dim=1) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = 
dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".to_out.0" in name: + name = name.replace(".to_out.0", ".to_out") + # Handle FeedForward parameter mapping + if ".ff.net." in name: + # Map .ff.net.0.proj -> .ff.w_in + if ".net.0.proj" in name: + name = name.replace(".net.0.proj", ".w_in") + # Map .ff.net.2 -> .ff.w_out + elif ".net.2" in name: + name = name.replace(".net.2", ".w_out") + # Handle FeedForward context parameters + if ".ff_context.net." in name: + # Map .ff_context.net.0.proj -> .ff_context.w_in + if ".net.0.proj" in name: + name = name.replace(".net.0.proj", ".w_in") + # Map .ff_context.net.2 -> .ff_context.w_out + elif ".net.2" in name: + name = name.replace(".net.2", ".w_out") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py new file mode 100644 index 0000000000000000000000000000000000000000..09f409f3139d51cfe38f3cc14c32373d087deb9f --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,678 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import inspect +import json +import os +import re +from collections.abc import Iterable +from functools import partial +from typing import Any + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.pipelines.longcat_image.system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = init_logger(__name__) + + +def get_longcat_image_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + 
with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters + defined by single or double quote pairs. + + Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." + >>> print(split_quotation(prompt_en)) + >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None) -> torch.Tensor: + if type == "text": + assert num_token + if height or width: + logger.warning('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + logger.warning('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknown type {type}, only support "text" or "image".') + # pos_ids = pos_ids[None, :].repeat(batch_size, 1, 1) + return pos_ids + + +def retrieve_timesteps( + scheduler: SchedulerMixin, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
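+
+    Typical use in this file (sigmas drive the schedule and `mu` is forwarded to the scheduler):
+    `timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)`.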
+ + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def get_prompt_language(prompt): + pattern = re.compile(r"[\u4e00-\u9fff]") + if bool(pattern.search(prompt)): + return "zh" + return "en" + + +class LongCatImagePipeline(nn.Module, CFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="tokenizer", local_files_only=local_files_only + ) + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + self.transformer = LongCatImageTransformer2DModel(od_config=od_config) + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.prompt_template_encode_prefix = ( + "<|im_start|>system\n" + "As an image captioning expert, generate a descriptive text prompt based on an image content," + " suitable for input to a text-to-image model.<|im_end|>\n" + "<|im_start|>user\n" + ) + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device) + + self.text_encoder.to(device) + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.text_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + return output_text + + def _encode_prompt(self, prompt: list[str]) -> torch.Tensor: + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + 
tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f"{self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": batch_all_tokens}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if prompt_embeds is None and prompt is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt) + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), 
num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + "`height` and `width` have to be divisible by " + f"{self.vae_scale_factor * 2} but are {height} and {width}. " + "Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): + """ + Normalize the combined noise prediction. 
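+
+        Rescales `comb_pred` so that its norm does not exceed that of the conditional prediction:
+        scale = clamp(||noise_pred|| / (||comb_pred|| + 1e-8), min=cfg_renorm_min, max=1.0),
+        and the returned tensor is `comb_pred * scale` (norms taken over the last dimension).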
+ """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = comb_pred * scale + return noise_pred + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + enable_cfg_renorm: bool | None = True, + cfg_renorm_min: float | None = 0.0, + enable_prompt_rewrite: bool | None = True, + ) -> DiffusionOutput: + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + generator = req.sampling_params.generator or generator + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt is not None + else num_images_per_prompt + ) + enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite) + enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm) + cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + 
prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + if enable_prompt_rewrite and prompt is not None: + prompt = self.rewire_prompt(prompt if isinstance(prompt, list) else [prompt], device) + + negative_prompt = "" if negative_prompt is None else negative_prompt + + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + prompt_embeds = prompt_embeds.to(device) + if self.do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.to(device) + + # custom partial function with cfg_renorm_min + self.cfg_normalize_function = partial(self.cfg_normalize_function, cfg_renorm_min=cfg_renorm_min) + + # 6. 
Denoising loop + for i, t in enumerate(timesteps): + if self._interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + if self.do_classifier_free_guidance: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=enable_cfg_renorm, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..a34a2cca390172cc415ad5c424deff160473c3f6 --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -0,0 +1,707 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import math +import os +import re +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import ( + AutoTokenizer, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLProcessor, +) +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import ( + LongCatImageTransformer2DModel, +) +from 
vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import calculate_shift +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = init_logger(__name__) + + +def get_longcat_image_edit_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-processing function for LongCatImageEditPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + latent_channels = vae_config.get("latent_channels", 16) + + def pre_process_func( + request: OmniDiffusionRequest, + ): + """Pre-process requests for LongCatImageEditPipeline.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""" + ) + + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) + + image_size = image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width + + # Store calculated dimensions in request + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): + image = image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image_processor.resize(image, calculated_height // 2, calculated_width // 2) + image = image_processor.preprocess(image, calculated_height, calculated_width) + + # Store preprocessed image and prompt image in request + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request + + return pre_process_func + + +def get_longcat_image_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + 
vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 16 == 0 else (width // 16 + 1) * 16 + height = height if height % 16 == 0 else (height // 16 + 1) * 16 + + width = int(width) + height = int(height) + + return width, height + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + logger.warning('The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + logger.warning('The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknown type {type}, only support "text" or "image".') + return pos_ids + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" 
+ re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +class LongCatImageEditPipeline(nn.Module, CFGParallelMixin, SupportImageInput): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="text_processor", local_files_only=local_files_only + ) + + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + self.transformer = LongCatImageTransformer2DModel(od_config=od_config) + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor_vl = self.text_processor.image_processor + self.latent_channels = self.vae.config.get("latent_channels", 16) + + self.image_token = "<|image_pad|>" + self.prompt_template_encode_prefix = ( + "<|im_start|>system\n" + "As an image editing expert, first analyze the content and attributes of the input image(s). " + "Then, based on the user's editing instructions, clearly and precisely determine how to modify " + "the given image(s), " + "ensuring that only the specified parts are altered and all other aspects remain consistent " + "with the original(s)." 
+            "<|im_end|>\n"
+            "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+        )
+        self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n"
+
+        self.default_sample_size = 128
+        self.tokenizer_max_length = 512
+
+    def _encode_prompt(self, prompt, image):
+        raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt")
+        pixel_values = raw_vl_input["pixel_values"]
+        image_grid_thw = raw_vl_input["image_grid_thw"]
+        all_tokens = []
+        for clean_prompt_sub, matched in split_quotation(prompt[0]):
+            if matched:
+                for sub_word in clean_prompt_sub:
+                    tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"]
+                    all_tokens.extend(tokens)
+            else:
+                tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"]
+                all_tokens.extend(tokens)
+
+        if len(all_tokens) > self.tokenizer_max_length:
+            logger.warning(
+                "Your input was truncated because `max_sequence_length` is set to "
+                f"{self.tokenizer_max_length}, but the input has {len(all_tokens)} tokens."
+            )
+            all_tokens = all_tokens[: self.tokenizer_max_length]
+
+        text_tokens_and_mask = self.tokenizer.pad(
+            {"input_ids": [all_tokens]},
+            max_length=self.tokenizer_max_length,
+            padding="max_length",
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+
+        text = self.prompt_template_encode_prefix
+
+        merge_length = self.image_processor_vl.merge_size**2
+        while self.image_token in text:
+            num_image_tokens = image_grid_thw.prod() // merge_length
+            text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+        text = text.replace("<|placeholder|>", self.image_token)
+
+        prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"]
+        suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"]
+
+        vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
+        prefix_len = prefix_tokens.index(vision_start_token_id)
+        suffix_len = len(suffix_tokens)
+
+        prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
+        suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
+
+        prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
+        suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
+
+        input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
+        attention_mask = torch.cat(
+            (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1
+        )
+
+        input_ids = input_ids.unsqueeze(0).to(self.device)
+        attention_mask = attention_mask.unsqueeze(0).to(self.device)
+
+        pixel_values = pixel_values.to(self.device)
+        image_grid_thw = image_grid_thw.to(self.device)
+
+        text_output = self.text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            output_hidden_states=True,
+        )
+        # Take the last hidden state and strip the chat-template prefix/suffix tokens
+        prompt_embeds = text_output.hidden_states[-1].detach()
+        prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :]
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: list[str] | None = None,
+        image: torch.Tensor | None = None,
+        num_images_per_prompt: int | None = 1,
+        prompt_embeds: torch.Tensor | None = None,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = 
len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds, text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + prompt_embeds_length, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + image_latents, image_latents_ids = None, None + + if image is not None: + image = image.to(device=self.device, dtype=dtype) + + if image.shape[1] != self.vae.config.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + + image_latents_ids = prepare_pos_ids( + modality_id=2, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device, dtype=torch.float64) + + shape = (batch_size, num_channels_latents, height, width) + latents_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latents_ids, image_latents_ids + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + "`height` and `width` have to be divisible by " + f"{self.vae_scale_factor * 2} but are {height} and {width}. " + "Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None: + if isinstance(prompt, str): + pass + elif isinstance(prompt, list) and len(prompt) == 1: + pass + else: + raise ValueError( + f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | torch.Tensor | None = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + ): + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. 
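+        # NOTE (illustrative, not a documented contract): an online request may arrive with explicit
+        # None values, e.g. req.prompts = [{"prompt": "...", "negative_prompt": None}], which is why
+        # the code below uses `first_prompt.get(...) or ""` rather than a `.get(..., "")` default.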
+        if len(req.prompts) > 1:
+            logger.warning(
+                "This model only supports a single prompt, not a batched request. "
+                "Taking only the first prompt for now."
+            )
+        first_prompt = req.prompts[0]
+        prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+        negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt")
+        prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds")
+        negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds")  # type: ignore  # Why is it list[torch.Tensor] in OmniTokenInputs or OmniEmbedsPrompt?
+
+        sigmas = req.sampling_params.sigmas or sigmas
+        guidance_scale = (
+            req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
+        )
+        num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        num_images_per_prompt = (
+            req.sampling_params.num_outputs_per_prompt
+            if req.sampling_params.num_outputs_per_prompt is not None
+            else num_images_per_prompt
+        )
+        generator = req.sampling_params.generator or generator
+        height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
+        width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
+
+        if prompt is not None:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if not isinstance(first_prompt, str) and "preprocessed_image" in (
+            additional_information := first_prompt.get("additional_information", {})
+        ):
+            prompt_image = additional_information.get("prompt_image")
+            image = additional_information.get("preprocessed_image")
+            calculated_height = additional_information.get("calculated_height", height)
+            calculated_width = additional_information.get("calculated_width", width)
+        else:
+            image_size = image[0].size if isinstance(image, list) else image.size
+            calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1])
+
+        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+            image = self.image_processor.resize(image, calculated_height, calculated_width)
+            prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2)
+            image = self.image_processor.preprocess(image, calculated_height, calculated_width)
+
+        self.check_inputs(
+            prompt,
+            calculated_height,
+            calculated_width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+
+        (prompt_embeds, text_ids) = self.encode_prompt(
+            prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
+        )
+
+        if guidance_scale > 1:
+            (negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
+                prompt=negative_prompt,
+                image=prompt_image,
+                prompt_embeds=negative_prompt_embeds,
+                num_images_per_prompt=num_images_per_prompt,
+            )
+
+        device = self.device
+
+        # Prepare latent variables
+        num_channels_latents = 16
+        latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
+            image,
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            calculated_height,
+            calculated_width,
+            prompt_embeds.dtype,
+            prompt_embeds.shape[1],
+            device,
+            generator,
+            latents,
+        )
+
+        # Prepare timesteps
+        sigmas = np.linspace(1.0, 1.0 / num_inference_steps, 
num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + guidance = None + + if image is not None: + latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) + else: + latent_image_ids = latents_ids + + for i, t in enumerate(timesteps): + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + do_true_cfg = guidance_scale > 1 + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=image_seq_len, + ) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/ovis_image/__init__.py b/vllm_omni/diffusion/models/ovis_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7beae576ae7d469064216b7bd928340d7799a8 --- /dev/null +++ b/vllm_omni/diffusion/models/ovis_image/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Ovis Image 7B diffusion model components.""" + +from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import ( + OvisImageTransformer2DModel, +) +from vllm_omni.diffusion.models.ovis_image.pipeline_ovis_image import ( + OvisImagePipeline, + get_ovis_image_post_process_func, +) + +__all__ = [ + "OvisImagePipeline", + "OvisImageTransformer2DModel", + "get_ovis_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py 
b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6cf6b0ccffe42b24eae8396a1a159483f1bcf6 --- /dev/null +++ b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py @@ -0,0 +1,544 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from diffusers.utils import is_torch_npu_available +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +logger = init_logger(__name__) + + +class OvisImageAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + disable_tp=True, + bias=bias, + ) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + + self.add_kv_proj = QKVParallelLinear( + hidden_size=self.added_kv_proj_dim, + head_size=self.head_dim, + 
total_num_heads=self.heads, + disable_tp=True, + bias=added_proj_bias, + ) + + self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.to_qkv(hidden_states) + + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb # [S, D/2] + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class OvisImageSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim * 2) + self.act_mlp = nn.SiLU() + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = OvisImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states, mlp_hidden_gate = torch.split( + self.proj_mlp(norm_hidden_states), [self.mlp_hidden_dim, self.mlp_hidden_dim], dim=-1 + ) + mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * 
mlp_hidden_states + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class OvisImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = OvisImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
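+        # The text (context) stream below mirrors the image stream above: a gated attention residual
+        # followed by a gated feed-forward residual, both modulated by the AdaLayerNormZero outputs
+        # (c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp) derived from the timestep embedding.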
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class OvisImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class OvisImageTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Ovis-Image. + + Reference: https://github.com/AIDC-AI/Ovis-Image + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `6`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `27`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `2048`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
+ """ + + _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 1, + in_channels: int = 64, + out_channels: int | None = 64, + num_layers: int = 6, + num_single_layers: int = 27, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 2048, + axes_dims_rope: tuple[int] = (16, 56, 56), + ): + super().__init__() + model_config = od_config.tf_model_config + num_layers = model_config.num_layers + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pos_embed = OvisImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.context_embedder_norm = RMSNorm(joint_attention_dim, eps=1e-6) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + OvisImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + OvisImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`OvisImageTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + img_ids: (`torch.Tensor`): + The position ids for image tokens. + txt_ids (`torch.Tensor`): + The position ids for text tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + + hidden_states = self.x_embedder(hidden_states) + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) * 1000 + + timesteps_proj = self.time_proj(timestep) + temb = self.timestep_embedder(timesteps_proj.to(device=hidden_states.device, dtype=hidden_states.dtype)) + + encoder_hidden_states = self.context_embedder_norm(encoder_hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + if is_torch_npu_available(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # self attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py new file mode 100644 index 0000000000000000000000000000000000000000..963f1c483b35b2044dd7e6a202399a90922c336f --- /dev/null +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -0,0 +1,741 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +from collections.abc import Callable, Iterable +from typing import Any + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import Qwen2TokenizerFast, Qwen3Model +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +logger = init_logger(__name__) + + +def get_ovis_image_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timetemps + from the scheduler after the call. Handles custom timeteps. Any kwargs will be supplied to `scheduler.set_timeteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"the current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accepts_timesteps = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"the current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OvisImagePipeline(nn.Module, CFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self._execution_device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + self.text_encoder = Qwen3Model.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self._execution_device + ) + + self.tokenizer = Qwen2TokenizerFast.from_pretrained( + model, subfolder="tokenizer", local_files_only=local_files_only + ) + + self.transformer = OvisImageTransformer2DModel(od_config=od_config) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.tokenizer_max_length = 1024 + self.system_prompt = """Describe the image by detailing the color, quantity, text, shape, size, texture, spatial + relationships of the objects and background: """ + self.user_prompt_begin_id = 28 + self.tokenizer_max_length = 256 
+ self.user_prompt_begin_id + self.default_sample_size = 128 + + def _get_messages( + self, + prompt: str | list[str] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages = [] + + for each_prompt in prompt: + message = [ + { + "role": "user", + "content": self.system_prompt + each_prompt, + } + ] + message = self.tokenizer.apply_chat_template( + message, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + messages.append(message) + + return messages + + def _get_ovis_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + messages = self._get_messages(prompt) + + batch_size = len(messages) + + tokens = self.tokenizer( + messages, + padding="max_length", + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors="pt", + add_special_tokens=False, + ) + + input_ids = tokens.input_ids.to(device=device) + attention_mask = tokens.attention_mask.to(device=device) + + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + prompt_embeds = outputs.last_hidden_state + prompt_embeds = prompt_embeds * attention_mask[..., None] + prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id :, :] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch.device + num_images_per_prompt: (`int`): + number of images that should be generated per prompt + prompt_embeds: (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, text embeddings will be generated from `prompt` input argument. + """ + + device = device or self._execution_device + + if prompt_embeds is None: + prompt_embeds = self._get_ovis_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + text_ids[..., 1] = text_ids[..., 1] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids[..., 2] = text_ids[..., 2] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids = text_ids.to(device=device, dtype=dtype) + return prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"""`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are + {height} and {width}. 
Dimension will be resized accordingly""" + ) + + # if callback_on_step_end_tensor_inputs is not None and not all( + # k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + # ): + # raise ValueError( + # f"""`callback_on_step_end_tensor_inputs` has to contain the following keys: + # {self._callback_tensor_inputs.keys()}""" + # ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list[str]` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: " + f"{negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` has to be less than or equal to 256 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channel_latents, height, width): + latents = latents.view(batch_size, num_channel_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channel_latents * 4) + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2 + height = int(2 * (int(height) // (vae_scale_factor * 2))) + width = int(2 * (int(width) // (vae_scale_factor * 2))) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + def prepare_latents( + self, + batch_size, + num_channel_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
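+        # Example (assuming the default vae_scale_factor of 8): a 1024x1024 request maps to a
+        # 128x128 latent grid, which _pack_latents turns into (128 // 2) * (128 // 2) = 4096
+        # tokens, each of dimension num_channel_latents * 4.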
+ height = int(2 * (int(height) // (self.vae_scale_factor * 2))) + width = int(2 * (int(width) // (self.vae_scale_factor * 2))) + + shape = (batch_size, num_channel_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channel_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + self._execution_device, + sigmas=sigmas, + mu=mu, + ) + return timesteps, num_inference_steps + + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + text_ids: torch.Tensor, + negative_text_ids: torch.Tensor, + latent_image_ids: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. 
+ + Args: + latents: Noise latents to denoise + timesteps: Diffusion timesteps + prompt_embeds: Positive prompt embeddings + negative_prompt_embeds: Negative prompt embeddings + text_ids: Position IDs for positive text + negative_text_ids: Position IDs for negative text + latent_image_ids: Position IDs for image latents + do_true_cfg: Whether to apply CFG + guidance_scale: CFG scale factor + cfg_normalize: Whether to normalize CFG output (default: False) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + + for i, t in enumerate(timesteps): + if self.interrupt: + break + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + guidance_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + guidance_scale: float = 5.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ) -> DiffusionOutput: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + guidance_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and + `negative_prompt` is provided. 
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 
+ + Examples: + + Returns: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + # Steps: + # 1. Check Inputs + # 2. encode prompts + # 4. Prepare latents + # 5. Prepare timesteps + # 6. diffusion latents + # 7. decode latents + # 8. post process outputs + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + device = self._execution_device + device = self._execution_device + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_text_ids = None + if do_classifier_free_guidance: + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channel_latents = self.transformer.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channel_latents=num_channel_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 5. 
Prepare timesteps + + image_seq_len = latents.shape[1] + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, image_seq_len) + + # num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop using diffuse method + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds if do_classifier_free_guidance else None, + text_ids=text_ids, + negative_text_ids=negative_text_ids if do_classifier_free_guidance else None, + latent_image_ids=latent_image_ids, + do_true_cfg=do_classifier_free_guidance, + guidance_scale=guidance_scale, + cfg_normalize=False, + ) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/qwen_image/__init__.py b/vllm_omni/diffusion/models/qwen_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b823ec75dc66f377e04f353c18cc038c843dfad --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Qwen Image diffusion model components.""" + +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( + QwenImagePipeline, + get_qwen_image_post_process_func, +) +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, +) + +__all__ = [ + "QwenImageCFGParallelMixin", + "QwenImagePipeline", + "QwenImageTransformer2DModel", + "get_qwen_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/qwen_image/autoencoder_kl_qwenimage.py b/vllm_omni/diffusion/models/qwen_image/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5c9ef7069ac50cccee4972ce4b063b3d0b3e68 --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/autoencoder_kl_qwenimage.py @@ -0,0 +1,1054 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. 
+# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - Paper: https://huggingface.co/papers/2503.20314 + +# Copied from diffusers to avoid version coupling. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.models.activations import get_activation +from diffusers.models.autoencoders.vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. 
+ + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. 
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
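+ num_layers (int): Number of attention/residual block pairs appended after the first residual block. Default is 1.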
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + 
if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: str | None = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temporal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
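+ input_channels (int): Number of channels in the decoded output (e.g. 3 for RGB). Default is 3.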
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temporal_upsample=[False, True, True], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temporal_upsample = temporal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temporal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: tuple[int, ...] 
= (1, 2, 4, 4), + num_res_blocks: int = 2, + attn_scales: list[float] = [], + temperal_downsample: list[bool] = [False, True, True], + dropout: float = 0.0, + input_channels: int = 3, + latents_mean: list[float] = ([-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, + 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, + 0.2503, -0.2921]), + latents_std: list[float] = ([2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]), + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, + self.temperal_downsample, dropout, input_channels + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout, input_channels + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. 
This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9a882f7bf0d7ef3225c42bc5f2f456448b4778db --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CFG Parallel Mixin for Qwen Image series +Shared by +- QwenImagePipeline +- QwenImageEditPipeline +- QwenImageEditPlusPipeline +- QwenImageLayeredPipeline +""" + +import logging +from typing import Any + +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size + +logger = logging.getLogger(__name__) + + +class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( + self, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. 
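+
+ Each step expands the current timestep to the batch, optionally concatenates conditional `image_latents` onto
+ the noise latents (editing pipelines), predicts noise for the positive prompt and, when `do_true_cfg` is set,
+ the negative prompt via `predict_noise_maybe_with_cfg`, then advances the latents with
+ `scheduler_step_maybe_with_cfg`.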
+ + Args: + prompt_embeds: Positive prompt embeddings + prompt_embeds_mask: Mask for positive prompt + negative_prompt_embeds: Negative prompt embeddings + negative_prompt_embeds_mask: Mask for negative prompt + latents: Noise latents to denoise + img_shapes: Image shape information + txt_seq_lens: Text sequence lengths for positive prompts + negative_txt_seq_lens: Text sequence lengths for negative prompts + timesteps: Diffusion timesteps + do_true_cfg: Whether to apply CFG + guidance: Guidance scale tensor + true_cfg_scale: CFG scale factor + image_latents: Conditional image latents for editing (default: None) + cfg_normalize: Whether to normalize CFG output (default: True) + additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + self.transformer.do_true_cfg = do_true_cfg + additional_transformer_kwargs = additional_transformer_kwargs or {} + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Concatenate image latents with noise latents if available (for editing pipelines) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + **additional_transformer_kwargs, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + **additional_transformer_kwargs, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents + + def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool): + """ + Validate whether CFG parallel is properly configured for the current generation request. + + When CFG parallel is enabled (cfg_parallel_world_size > 1), this method verifies that the necessary + conditions are met for correct parallel execution. If validation fails, a warning is + logged to help identify configuration issues. + + Args: + true_cfg_scale: The classifier-free guidance scale value. Must be > 1 for CFG to + have an effect. + has_neg_prompt: Whether negative prompts or negative prompt embeddings are provided. + Required for CFG to perform unconditional prediction. + + Returns: + True if CFG parallel is disabled or all validation checks pass, False otherwise. 
+ + Note: + When CFG parallel is disabled (world_size == 1), this method always returns True + as no parallel-specific validation is needed. + """ + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False + + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." + ) + return False + return True diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..d85d98b5bf5ea924cbd046f118d3eab1e301266e --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -0,0 +1,718 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import json +import logging +import math +import os +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +import torch.distributed +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import ( + AutoencoderKLQwenImage, +) +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_qwen_image_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: 
list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
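+
+ Concretely, frequencies are f_j = max_period ** (-j / (half_dim - downscale_freq_shift)) for j in [0, half_dim),
+ and each row is [sin(scale * t * f_0), ..., sin(scale * t * f_{half_dim-1}), cos(scale * t * f_0), ...]
+ (sin/cos halves swapped when flip_sin_to_cos is True, zero-padded by one column when embedding_dim is odd).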
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class QwenImagePipeline(nn.Module, QwenImageCFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, 
subfolder="text_encoder", local_files_only=local_files_only + ) + self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + + self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + + self.stage = None + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. + # This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + # self.image_processor = VaeImageProcessor( + # vae_scale_factor=self.vae_scale_factor * 2 + # ) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + # if callback_on_step_end_tensor_inputs is not None and not all( + # k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + # ): + # raise ValueError( + # f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, + # but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + # ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. " + "Make sure to generate `prompt_embeds_mask` from the same text encoder " + "that was used to generate `prompt_embeds`." 
+ ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. " + "Make sure to generate `negative_prompt_embeds_mask` from the same text encoder " + "that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + dtype: torch.dtype | None = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.tokenizer_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(self.device) + # print(f"attention mask: {txt_tokens.attention_mask}") + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
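+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Attention mask for `prompt_embeds`. Required whenever pre-generated `prompt_embeds` are passed.
+ max_sequence_length (`int`, defaults to 1024):
+ Prompt embeddings and masks are truncated to at most this many tokens.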
+ """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ) -> torch.Tensor: + # generator=torch.Generator(device="cuda").manual_seed(42) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + # image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ) -> DiffusionOutput: + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + # 1. check inputs + # 2. encode prompts + # 3. prepare latents and timesteps + # 4. diffusion process + # 5. decode latents + # 6. 
post-process outputs + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + img_shapes = [ + [ + ( + 1, + height // self.vae_scale_factor // 2, + width // self.vae_scale_factor // 2, + ) + ] + ] * batch_size + + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + # num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # print inputp params + + latents = self.diffuse( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + image_latents=None, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, + ) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, 
return_dict=False)[0][:, :, 0] + # processed_image = self.image_processor.postprocess(image, output_type=output_type) + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..78fd92c9d5b704f9c0bb234b490acf34ccff4f6b --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -0,0 +1,810 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import json +import logging +import math +import os +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import ( + AutoencoderKLQwenImage, +) +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_qwen_image_edit_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-processing function for QwenImageEditPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True) + latent_channels = vae_config.get("z_dim", 16) + + def pre_process_func( + request: OmniDiffusionRequest, + ): + """Pre-process requests for QwenImageEditPipeline.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if 
"additional_information" not in prompt: + prompt["additional_information"] = {} + + # Only handles single image + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""" + ) + + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) + + image_size = image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width + + # Ensure dimensions are multiples of vae_scale_factor * 2 + multiple_of = vae_scale_factor * 2 + height = height // multiple_of * multiple_of + width = width // multiple_of * multiple_of + + # Store calculated dimensions in request + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width + + # Preprocess image + if image is not None and not ( + isinstance(image, torch.Tensor) and len(image.shape) > 1 and image.shape[1] == latent_channels + ): + image = image_processor.resize(image, height, width) + prompt_image = image + image = image_processor.preprocess(image, height, width) + image = image.unsqueeze(2) + + # Store preprocessed image and prompt image in request + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request + + return pre_process_func + + +def get_qwen_image_edit_post_process_func( + od_config: OmniDiffusionConfig, +): + """Post-processing function for QwenImageEditPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_dimensions(target_area: float, ratio: float): + """Calculate width and height from target area and aspect ratio.""" + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "argmax" +): + """Retrieve latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist"): + return ( + encoder_output.latent_dist.mode() + if sample_mode == "argmax" + else encoder_output.latent_dist.sample(generator) + ) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class QwenImageEditPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + self.device = get_local_device() + model = od_config.model + + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + + self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="processor", local_files_only=local_files_only + ) + + self.stage = None + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_convert_rgb=True) + self.tokenizer_max_length = 1024 + # Edit prompt template - different from 
generation template + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + def check_inputs( + self, + prompt, + height, + width, + image=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. " + "Make sure to generate `prompt_embeds_mask` from the same text encoder " + "that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. " + "Make sure to generate `negative_prompt_embeds_mask` from the same text encoder " + "that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + dtype: torch.dtype | None = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(self.device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) + + return prompt_embeds, encoder_attention_mask + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + image: PIL.Image.Image | torch.Tensor | None = None, + dtype: torch.dtype | None = None, + ): + """Get prompt embeddings with image support for editing.""" + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + # Use processor to handle both text and image inputs + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(self.device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + 
return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + image: torch.Tensor | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
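+ # Worked example (with the default vae_scale_factor of 8): a 1024x1024 request gives
+ # height = 2 * (1024 // 16) = 128 and width = 128 latent cells, which _pack_latents
+ # below folds into (128 // 2) * (128 // 2) = 4096 patch tokens per image.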
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + image: PIL.Image.Image | torch.Tensor | None = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, + 
callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ) -> DiffusionOutput: + """Forward pass for image editing.""" + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + + # Get preprocessed image from request (pre-processing is done in DiffusionEngine) + if not isinstance(first_prompt, str) and "preprocessed_image" in ( + additional_information := first_prompt.get("additional_information", {}) + ): + prompt_image = additional_information.get("prompt_image") + image = additional_information.get("preprocessed_image") + calculated_height = additional_information.get("calculated_height") + calculated_width = additional_information.get("calculated_width") + height = req.sampling_params.height + width = req.sampling_params.width + else: + # fallback to run pre-processing in pipeline (debug only) + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + # 1. check inputs + # 2. encode prompts + # 3. prepare latents and timesteps + # 4. diffusion process + # 5. decode latents + # 6. 
post-process outputs + self.check_inputs( + prompt, + height, + width, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + image=prompt_image, # Use resized image for prompt encoding + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + image=prompt_image, # Use same resized image for negative prompt encoding + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.in_channels // 4 + # random noise latents, and image latents encoded by vae + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + latents = self.diffuse( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, + ) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / 
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..00e775802971afaa7b54d12aeb4b27d08e364e23 --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -0,0 +1,769 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import logging +import os +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import ( + AutoencoderKLQwenImage, +) +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( + calculate_dimensions, + retrieve_latents, + retrieve_timesteps, +) +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + + +def get_qwen_image_edit_plus_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-processing function for QwenImageEditPlusPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True) + latent_channels = vae_config.get("z_dim", 16) + + def pre_process_func( + request: OmniDiffusionRequest, + ): + """Pre-process requests for 
QwenImageEditPlusPipeline.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + # Handle single image or list of images + if raw_image is None: + continue + + if not isinstance(raw_image, list): + raw_image = [raw_image] + image = [ + PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) + for im in raw_image + ] + + # Calculate dimensions based on first image + image_size = image[0].size + calculated_width, calculated_height = calculate_dimensions(VAE_IMAGE_SIZE, image_size[0] / image_size[1]) + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width + + # Ensure dimensions are multiples of vae_scale_factor * 2 + multiple_of = vae_scale_factor * 2 + height = height // multiple_of * multiple_of + width = width // multiple_of * multiple_of + + # Store calculated dimensions in request + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width + + # Preprocess images into condition_images (for prompt encoding) and vae_images (for VAE encoding) + condition_images = [] + vae_images = [] + condition_image_sizes = [] + vae_image_sizes = [] + + for img in image: + if isinstance(img, torch.Tensor) and len(img.shape) > 1 and img.shape[1] == latent_channels: + # Already a latent tensor + continue + + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + + condition_images.append(image_processor.resize(img, condition_height, condition_width)) + vae_images.append(image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + # Store preprocessed images in request + prompt["additional_information"]["condition_images"] = condition_images + prompt["additional_information"]["vae_images"] = vae_images + prompt["additional_information"]["condition_image_sizes"] = condition_image_sizes + prompt["additional_information"]["vae_image_sizes"] = vae_image_sizes + request.prompts[i] = prompt + return request + + return pre_process_func + + +def get_qwen_image_edit_plus_post_process_func( + od_config: OmniDiffusionConfig, +): + """Post-processing function for QwenImageEditPlusPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True) + + def post_process_func( + images: torch.Tensor, + ): + return 
image_processor.postprocess(images) + + return post_process_func + + +class QwenImageEditPlusPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + self.device = get_local_device() + model = od_config.model + + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + + self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="processor", local_files_only=local_files_only + ) + + self.stage = None + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_convert_rgb=True) + self.tokenizer_max_length = 1024 + # Edit prompt template - different from generation template, supports multiple images + self.prompt_template_encode = ( + "<|im_start|>system\nDescribe the key features of the input image " + "(color, shape, size, texture, objects, background), then explain how the user's " + "text instruction should alter or modify the image. Generate a new image that meets " + "the user's requirements while maintaining consistency with the original input where " + "appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ) + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + def check_inputs( + self, + prompt, + height, + width, + image=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. " + "Make sure to generate `prompt_embeds_mask` from the same text encoder " + "that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. " + "Make sure to generate `negative_prompt_embeds_mask` from the same text encoder " + "that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str], + image: list[torch.Tensor] | torch.Tensor | None = None, + dtype: torch.dtype | None = None, + ): + """Get prompt embeddings with support for multiple images.""" + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Build image prompt template for multiple images + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + # Use processor to handle both text and image inputs + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(self.device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype) + + return prompt_embeds, 
encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + image: list[torch.Tensor] | torch.Tensor | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor` or `list[torch.Tensor]`, *optional*): + image(s) to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
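+ # Inverse of _pack_latents: e.g. with vae_scale_factor=8 and a 1024x1024 output, the
+ # (batch, 4096, packed_channels) patch sequence is reshaped back into a
+ # (batch, packed_channels // 4, 1, 128, 128) latent volume before VAE decoding.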
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + # Concatenate all image latents along dimension 1 + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ) -> DiffusionOutput: + """Forward pass for image editing with support for multiple images.""" + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. 
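+        # Only the first entry of req.prompts is consumed below; any additional prompts
+        # in a batched request are ignored after the warning.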
+        if len(req.prompts) > 1:
+            logger.warning(
+                "This model only supports a single prompt, not a batched request. "
+                "Taking only the first prompt for now."
+            )
+        first_prompt = req.prompts[0]
+        prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+        negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt")
+
+        # Get preprocessed images from request (pre-processing is done in DiffusionEngine)
+        if (
+            not isinstance(first_prompt, str)
+            and "vae_images" in (additional_information := first_prompt.get("additional_information", {}))
+            and "condition_images" in additional_information
+        ):
+            condition_images = additional_information.get("condition_images")
+            vae_images = additional_information.get("vae_images")
+            condition_image_sizes = additional_information.get("condition_image_sizes")
+            vae_image_sizes = additional_information.get("vae_image_sizes")
+            calculated_height = additional_information.get("calculated_height")
+            calculated_width = additional_information.get("calculated_width")
+            height = req.sampling_params.height
+            width = req.sampling_params.width
+        else:
+            # fallback to run pre-processing in pipeline (debug only)
+            if image is None:
+                raise ValueError("Image is required for QwenImageEditPlusPipeline")
+
+            if not isinstance(image, list):
+                image = [image]
+
+            image_size = image[0].size
+            calculated_width, calculated_height = calculate_dimensions(VAE_IMAGE_SIZE, image_size[0] / image_size[1])
+            height = height or calculated_height
+            width = width or calculated_width
+
+            multiple_of = self.vae_scale_factor * 2
+            width = width // multiple_of * multiple_of
+            height = height // multiple_of * multiple_of
+
+            condition_images = []
+            vae_images = []
+            condition_image_sizes = []
+            vae_image_sizes = []
+
+            for img in image:
+                image_width, image_height = img.size
+                condition_width, condition_height = calculate_dimensions(
+                    CONDITION_IMAGE_SIZE, image_width / image_height
+                )
+                vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+                condition_image_sizes.append((condition_width, condition_height))
+                vae_image_sizes.append((vae_width, vae_height))
+                condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+                vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+        num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        sigmas = req.sampling_params.sigmas or sigmas
+        max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
+        generator = req.sampling_params.generator or generator
+        true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
+        if req.sampling_params.guidance_scale_provided:
+            guidance_scale = req.sampling_params.guidance_scale
+        num_images_per_prompt = (
+            req.sampling_params.num_outputs_per_prompt
+            if req.sampling_params.num_outputs_per_prompt > 0
+            else num_images_per_prompt
+        )
+
+        # 1. check inputs
+        # 2. encode prompts
+        # 3. prepare latents and timesteps
+        # 4. diffusion process
+        # 5. decode latents
+        # 6. 
post-process outputs + self.check_inputs( + prompt, + height, + width, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + image=condition_images, # Use condition images for prompt encoding + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + image=condition_images, # Use same condition images for negative prompt encoding + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.in_channels // 4 + # random noise latents, and image latents encoded by vae + latents, image_latents = self.prepare_latents( + vae_images, # Use VAE images for latent preparation + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + # img_shapes includes shapes for output image and all input images + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + latents = self.diffuse( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, + ) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + 
latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py new file mode 100644 index 0000000000000000000000000000000000000000..d200764ebfb87502e5f997e7240bcb579dbb9563 --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -0,0 +1,846 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import json +import logging +import math +import os +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( + AutoencoderKLQwenImage, +) +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + QwenImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +# Interface called in diffusion engine +def get_qwen_image_layered_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-processing function for QwenImageLayeredPipeline.""" + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8 + latent_channels = vae_config.get("z_dim", 16) + + # QwenImage latents are turned into 2x2 patches and packed. + # This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied + # by the patch size to account for this + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def pre_process_func( + request: OmniDiffusionRequest, + ): + """Pre-process requests for QwenImageLayeredPipeline.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""" + ) + + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) + + # 1. calculate dimensions + image_size = image.size + assert request.sampling_params.resolution in [640, 1024], ( + f"resolution must be either 640 or 1024, but got {request.sampling_params.resolution}" + ) + calculated_width, calculated_height = calculate_dimensions( + request.sampling_params.resolution * request.sampling_params.resolution, image_size[0] / image_size[1] + ) + height = calculated_height + width = calculated_width + + multiple_of = vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # Store calculated dimensions in request + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width + + # 2. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): + image = image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + # image = image.to(dtype=self.text_encoder.dtype) # do it later + + # Store preprocessed image and prompt image in request + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request + + return pre_process_func + + +# Copied from diffusers to avoid version coupling. +# Upstream code merged on 2025-12-18. +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class QwenImageLayeredPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): + color_format = "RGBA" + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + # modules keep same as transformers & diffusers + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="processor", local_files_only=local_files_only + ) + + # modules re-implemented for vLLM-Omni + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + + # Pipeline configuration & processing parameters + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are 
turned into 2x2 patches and packed. + # This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied + # by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.vl_processor = self.processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = ( + "<|im_start|>system\nDescribe the image by detailing the color, " + "shape, size, texture, quantity, text, spatial relationships of the objects and " + "background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ) + self.prompt_template_encode_start_idx = 34 + self.image_caption_prompt_cn = ( + """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# """ + """图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1. +使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n """ + """- 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n - +对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、""" + """光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在""" + """图注中强调\n3. +保持真实性与准确性:\n - 不要使用笼统的描述\n - +描述图像中所有可见的信息,但不要加入没有在图像中出现""" + """的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + ) + self.image_caption_prompt_en = ( + """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n""" + """<|im_start|>user\n# Image Annotator\nYou are a professional +image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural, +descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object +attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations +between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action +relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting, +colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or +explanation, and highlight it in the caption with quotation marks\n3. Maintain authenticity and accuracy:\n - Avoid +generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in +the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + ) + self.default_sample_size = 128 + + self.stage = None + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate" + " `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make" + " sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to " + "generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + return_tensors="pt", + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + layers, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = ( + batch_size, + layers + 1, + num_channels_latents, + height, + width, + ) ### the generated first image is combined image + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = image_latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) -> (b, f, c, h, w) + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width, 1 + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, layers + 1) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + def get_image_caption(self, prompt_image, use_en_prompt=True, device=None): + if use_en_prompt: + prompt = self.image_caption_prompt_en + else: + prompt = self.image_caption_prompt_cn + model_inputs = self.vl_processor( + text=prompt, + images=prompt_image, + padding=True, + return_tensors="pt", + ).to(device) + generated_ids = self.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = self.vl_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return output_text.strip() + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers): + latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, layers, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
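+        # Note: the packed sequence holds (layers + 1) * (latent_h // 2) * (latent_w // 2)
+        # patch tokens (the combined image plus one token grid per layer); they are
+        # unfolded below into a (B, C, layers + 1, H, W) latent volume so the VAE can
+        # decode each layer as a separate frame.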
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + + latents = latents.reshape(batch_size, layers + 1, channels // (2 * 2), height, width) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | torch.Tensor | None = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + true_cfg_scale: float = 4.0, + layers: int | None = 4, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, + max_sequence_length: int = 512, + resolution: int = 640, + cfg_normalize: bool = False, + use_en_prompt: bool = False, + ) -> DiffusionOutput: + """Forward pass for image layered.""" + + # 1. Get preprocessed image from request (pre-processing is done in DiffusionEngine) + # Override parameters from request if provided + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. 
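+        # Request-level sampling params (layers, resolution, steps, true_cfg_scale, ...)
+        # take precedence over the keyword defaults of forward(); only the first prompt
+        # of the request is used.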
+        if len(req.prompts) > 1:
+            logger.warning(
+                "This model only supports a single prompt, not a batched request. "
+                "Taking only the first prompt for now."
+            )
+        first_prompt = req.prompts[0]
+        prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+        negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt")
+
+        layers = req.sampling_params.layers if req.sampling_params.layers is not None else layers
+        resolution = req.sampling_params.resolution if req.sampling_params.resolution is not None else resolution
+        max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
+        cfg_normalize = (
+            req.sampling_params.cfg_normalize if req.sampling_params.cfg_normalize is not None else cfg_normalize
+        )
+        use_en_prompt = (
+            req.sampling_params.use_en_prompt if req.sampling_params.use_en_prompt is not None else use_en_prompt
+        )
+        num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        sigmas = req.sampling_params.sigmas or sigmas
+        generator = req.sampling_params.generator or generator
+        true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
+        if req.sampling_params.guidance_scale_provided:
+            guidance_scale = req.sampling_params.guidance_scale
+        num_images_per_prompt = (
+            req.sampling_params.num_outputs_per_prompt
+            if req.sampling_params.num_outputs_per_prompt > 0
+            else num_images_per_prompt
+        )
+
+        if not isinstance(first_prompt, str) and "preprocessed_image" in (
+            additional_information := first_prompt.get("additional_information", {})
+        ):
+            prompt_image = additional_information.get("prompt_image")
+            image = additional_information.get("preprocessed_image")
+            image = image.to(dtype=self.text_encoder.dtype)  # cast now that the text encoder dtype is known
+            calculated_height = additional_information.get("calculated_height")
+            calculated_width = additional_information.get("calculated_width")
+            height = req.sampling_params.height
+            width = req.sampling_params.width
+        else:
+            # fallback to run pre-processing in pipeline (debug only)
+            image_size = image[0].size if isinstance(image, list) else image.size
+            assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
+            calculated_width, calculated_height = calculate_dimensions(
+                resolution * resolution, image_size[0] / image_size[1]
+            )
+            height = calculated_height
+            width = calculated_width
+
+            multiple_of = self.vae_scale_factor * 2
+            width = width // multiple_of * multiple_of
+            height = height // multiple_of * multiple_of
+
+            if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+                image = self.image_processor.resize(image, calculated_height, calculated_width)
+                prompt_image = image
+                image = self.image_processor.preprocess(image, calculated_height, calculated_width)
+                image = image.unsqueeze(2)
+                image = image.to(dtype=self.text_encoder.dtype)
+
+        # 2. check inputs
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 3. 
encode prompot & negative prompt + if prompt is None or prompt == "" or prompt == " ": + prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=self.device) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free " + f"guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=self.device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=self.device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + layers, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + img_shapes = [ + [ + *[ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2) + for _ in range(layers + 1) + ], + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + base_seqlen = 256 * 256 / 16 / 16 + mu = (image_latents.shape[1] / base_seqlen) ** 0.5 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + self.device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=self.device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." 
+ ) + guidance = None + elif not self.transformer.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + is_rgb = torch.tensor([0] * batch_size).to(device=self.device, dtype=torch.long) + + latents = self.diffuse( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + image_latents=image_latents, + cfg_normalize=cfg_normalize, + additional_transformer_kwargs={ + "return_dict": False, + "additional_t_cond": is_rgb, + "attention_kwargs": self.attention_kwargs, + }, + ) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + + b, c, f, h, w = latents.shape + + latents = latents[:, :, 1:] # remove the first frame as it is the origin input + + latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) + + image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w + + image = image.squeeze(2) + # Maybe extract post process in the future + image = self.image_processor.postprocess(image, output_type=output_type) + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + return DiffusionOutput(output=images) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3cc6cbbb67b38c0c10b7dff991cef849e62e902 --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -0,0 +1,1084 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +from collections.abc import Iterable +from math import prod +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# TODO replace this with vLLM implementation +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionMetadata, +) +from vllm_omni.diffusion.attention.layer import Attention +from 
vllm_omni.diffusion.cache.base import CachedTransformer +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) +from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +logger = init_logger(__name__) + + +class ImageRopePrepare(nn.Module): + """Prepares image hidden_states and RoPE embeddings for sequence parallel. + + This module encapsulates the input linear projection and RoPE computation. + Similar to Z-Image's UnifiedPrepare, this creates a module boundary where + _sp_plan can shard outputs via split_output=True. + + The key insight is that hidden_states and vid_freqs must be sharded together + to maintain dimension alignment for RoPE computation in attention layers. + + Note: Our _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism). + """ + + def __init__(self, img_in: nn.Linear, pos_embed: nn.Module): + super().__init__() + self.img_in = img_in + self.pos_embed = pos_embed + + def forward( + self, + hidden_states: torch.Tensor, + img_shapes: list[tuple[int, int, int]], + txt_seq_lens: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepare hidden_states and RoPE for SP. + + Args: + hidden_states: [batch, img_seq_len, channels] + img_shapes: List of (frame, height, width) tuples + txt_seq_lens: List of text sequence lengths + + Returns: + hidden_states: Processed hidden states [batch, img_seq_len, dim] + vid_freqs: Image RoPE frequencies [img_seq_len, rope_dim] + txt_freqs: Text RoPE frequencies [txt_seq_len, rope_dim] + + Note: _sp_plan will shard hidden_states and vid_freqs via split_output=True + txt_freqs is kept replicated for dual-stream attention + """ + # Apply input projection + hidden_states = self.img_in(hidden_states) + + # Compute RoPE embeddings + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + vid_freqs, txt_freqs = image_rotary_emb + + return hidden_states, vid_freqs, txt_freqs + + +class ModulateIndexPrepare(nn.Module): + """Prepares modulate_index for sequence parallel when zero_cond_t is enabled. + + This module encapsulates the creation of modulate_index tensor, which is used + to select different conditioning parameters (shift/scale/gate) for different + token positions in image editing tasks. + + Similar to Z-Image's UnifiedPrepare and ImageRopePrepare, this creates a module + boundary where _sp_plan can shard the output via split_output=True. + + The modulate_index must be sharded along the sequence dimension to match the + sharded hidden_states in SP mode. + + Note: Our _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism). + """ + + def __init__(self, zero_cond_t: bool = False): + super().__init__() + self.zero_cond_t = zero_cond_t + + def forward( + self, + timestep: torch.Tensor, + img_shapes: list[list[tuple[int, int, int]]], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Prepare timestep and modulate_index for SP. + + Args: + timestep: Timestep tensor [batch] + img_shapes: List of image shape tuples per batch item. + Each item is a list of (frame, height, width) tuples. 
+ For edit models: [[source_shape], [target_shape1, target_shape2, ...]] + + Returns: + timestep: Doubled timestep if zero_cond_t, else original [batch] or [2*batch] + modulate_index: Token condition index [batch, seq_len] if zero_cond_t, else None + - index=0: source image tokens (use normal timestep conditioning) + - index=1: target image tokens (use zero timestep conditioning) + + Note: _sp_plan will shard modulate_index via split_output=True when SP is enabled. + The modulate_index sequence dimension must match hidden_states after sharding. + """ + if self.zero_cond_t: + # Double the timestep: [timestep, timestep * 0] + # This creates two sets of conditioning parameters in AdaLayerNorm + timestep = torch.cat([timestep, timestep * 0], dim=0) + + # Create modulate_index to select conditioning per token position + # - First image (sample[0]): source image, use index=0 (normal timestep) + # - Remaining images (sample[1:]): target images, use index=1 (zero timestep) + modulate_index = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], + device=timestep.device, + dtype=torch.int, + ) + return timestep, modulate_index + + return timestep, None + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, use_additional_t_cond=False): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = nn.Embedding(2, embedding_dim) + + def forward(self, timestep, hidden_states, addition_t_cond=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + if self.use_additional_t_cond: + if addition_t_cond is None: + raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.") + addition_t_emb = self.addition_t_embedding(addition_t_cond) + addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype) + conditioning = conditioning + addition_t_emb + + return conditioning + + +class QwenEmbedLayer3DRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of 
the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx) + else: + ### For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.cache + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + @functools.cache + def _compute_condition_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = 
torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index: torch.Tensor, dim: int, theta: int = 10000): + """ + Args: + index (`torch.Tensor`): [0, 1, 2, 3] 1D Tensor representing the position index of the token + dim (`int`): Dimension for the rope parameters + theta (`int`): Theta parameter for rope + """ + assert dim % 2 == 0 + freqs = torch.outer( + index, + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] + else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
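+        # Text tokens take positions starting right after the largest image axis index
+        # (e.g. a 64x64 latent grid with scale_rope gives max_vid_index = 32), so image
+        # and text RoPE frequencies never overlap along the shared positional axis.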
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.cache + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], + dim=0, + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], + dim=0, + ) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class ColumnParallelApproxGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." 
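+        # Tensor-parallel FFN: net.0 is a column-sharded up-projection followed by approximate
+        # GELU and net.2 a row-parallel down-projection; the nn.Identity at net.1 only preserves
+        # the module indices expected by the pretrained checkpoint (see the placeholder note below).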
+ + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), # placeholder for weight loading + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class QwenImageCrossAttention(nn.Module): + def __init__( + self, + dim: int, # query_dim + num_heads: int, + head_dim: int, + added_kv_proj_dim: int, + window_size: tuple[int, int] = (-1, -1), + out_bias: bool = True, + qk_norm: bool = True, + eps: float = 1e-6, + pre_only: bool = False, + context_pre_only: bool = False, + out_dim: int | None = None, + ) -> None: + super().__init__() + assert dim % num_heads == 0 + + self.dim = dim + self.head_dim = head_dim + self.total_num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.inner_dim = out_dim if out_dim is not None else head_dim * self.total_num_heads + + assert context_pre_only is not None + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=head_dim, + total_num_heads=num_heads, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + + assert not context_pre_only + self.to_add_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + assert not pre_only + self.to_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + self.norm_added_q = RMSNorm(head_dim, eps=eps) + self.norm_added_k = RMSNorm(head_dim, eps=eps) + + self.attn = Attention( + num_heads=self.query_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.kv_num_heads, + ) + self.rope = RotaryEmbedding(is_neox_style=False) + + try: + config = get_forward_context().omni_diffusion_config + self.parallel_config = config.parallel_config + except Exception: + self.parallel_config = None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + vid_freqs: torch.Tensor, + txt_freqs: torch.Tensor, + hidden_states_mask: torch.Tensor | None = None, + encoder_hidden_states_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + img_qkv, _ = self.to_qkv(hidden_states) + q_size = self.query_num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1) + + txt_qkv, _ = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_query_num_heads * self.head_dim + add_kv_size = self.add_kv_num_heads * self.head_dim + txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1) + + img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim)) + img_key = img_key.unflatten(-1, 
(self.kv_num_heads, self.head_dim)) + img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim)) + + txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim)) + txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + img_cos = vid_freqs.real.to(img_query.dtype) + img_sin = vid_freqs.imag.to(img_query.dtype) + txt_cos = txt_freqs.real.to(txt_query.dtype) + txt_sin = txt_freqs.imag.to(txt_query.dtype) + + img_query = self.rope(img_query, img_cos, img_sin) + img_key = self.rope(img_key, img_cos, img_sin) + txt_query = self.rope(txt_query, txt_cos, txt_sin) + txt_key = self.rope(txt_key, txt_cos, txt_sin) + + seq_len_txt = encoder_hidden_states.shape[1] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + if ( + self.parallel_config is not None + and self.parallel_config.sequence_parallel_size > 1 + and not get_forward_context().split_text_embed_in_sp + ): + attn_metadata = AttentionMetadata( + joint_query=txt_query, + joint_key=txt_key, + joint_value=txt_value, + joint_strategy="front", + ) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask + + joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata) + else: + attn_metadata = None + if hidden_states_mask is not None or encoder_hidden_states_mask is not None: + mask_list: list[torch.Tensor] = [] + if encoder_hidden_states_mask is not None: + mask_list.append(encoder_hidden_states_mask) + else: + mask_list.append( + torch.ones( + encoder_hidden_states.shape[:2], + dtype=torch.bool, + device=encoder_hidden_states.device, + ) + ) + if hidden_states_mask is not None: + mask_list.append(hidden_states_mask) + else: + mask_list.append( + torch.ones( + hidden_states.shape[:2], + dtype=torch.bool, + device=hidden_states.device, + ) + ) + joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] + attn_metadata = AttentionMetadata(attn_mask=joint_mask) + + joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata) + + joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype) + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + zero_cond_t: bool = False, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = QwenImageCrossAttention( + dim=dim, + num_heads=num_attention_heads, + added_kv_proj_dim=dim, + 
context_pre_only=False, + head_dim=attention_head_dim, + ) + self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim=dim, dim_out=dim) + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim) + + self.zero_cond_t = zero_cond_t + + def _modulate(self, x, mod_params, index=None): + """Apply modulation to input tensor""" + # x: b l d, shift: b d, scale: b d, gate: b d + shift, scale, gate = mod_params.chunk(3, dim=-1) + + if index is not None: + # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + # So shift, scale, gate have shape [2*actual_batch, d] + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + + # index: [b, l] where b is actual batch size + # Expand to [b, l, 1] to match feature dimension + index_expanded = index.unsqueeze(-1) # [b, l, 1] + + # Expand chunks to [b, 1, d] then broadcast to [b, l, d] + shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] + shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + + # Use torch.where to select based on index + shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + + return x * (1 + scale_result) + shift_result, gate_result + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], + joint_attention_kwargs: dict[str, Any] | None = None, + modulate_index: list[int] | None = None, + hidden_states_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + + if self.zero_cond_t: + temb = torch.chunk(temb, 2, dim=0)[0] + + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index) + + # Process text stream - norm1 + modulation + txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + vid_freqs=image_rotary_emb[0], + txt_freqs=image_rotary_emb[1], + hidden_states_mask=hidden_states_mask, + encoder_hidden_states_mask=encoder_hidden_states_mask, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index) + + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +# Note: inheriting from CachedTransformer only when we support caching +class QwenImageTransformer2DModel(CachedTransformer): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + # the small and frequently-repeated block(s) of a model + # -- typically a transformer layer + # used for torch compile optimizations + _repeated_blocks = ["QwenImageTransformerBlock"] + _layerwise_offload_blocks_attr = "transformer_blocks" + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + # Sequence Parallelism plan (following diffusers' _cp_plan pattern) + # Similar to Z-Image's UnifiedPrepare, we use ImageRopePrepare to create + # a module boundary where _sp_plan can shard hidden_states and vid_freqs together. + # + # Key insight: hidden_states and vid_freqs MUST be sharded together to maintain + # dimension alignment for RoPE computation in attention layers. 
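+ #
+ # Example with illustrative numbers: for sequence_parallel_size=4 and a 4096-token image sequence,
+ # rank r keeps hidden_states[:, 1024 * r : 1024 * (r + 1), :] together with the matching
+ # vid_freqs[1024 * r : 1024 * (r + 1), :], while txt_freqs stays replicated so every rank still sees
+ # the full text stream in the joint attention.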
+ # + # auto_pad=True enables automatic padding when sequence length is not divisible + # by SP world size. This creates an attention mask stored in ForwardContext + # that attention layers can use to ignore padding positions. + # + # Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism) + _sp_plan = { + # Shard ImageRopePrepare outputs (hidden_states and vid_freqs must be sharded together) + "image_rope_prepare": { + # hidden_states: auto_pad=True for variable sequence length support + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), + # vid_freqs: auto_pad=True to match hidden_states padding + 1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + # txt_freqs (index 2) is NOT sharded - kept replicated for dual-stream attention + }, + # Shard ModulateIndexPrepare output (modulate_index must be sharded to match hidden_states) + # This is only active when zero_cond_t=True (image editing models) + # Output index 1 is modulate_index [batch, seq_len], needs sharding along dim=1 + "modulate_index_prepare": { + 1: SequenceParallelInput(split_dim=1, expected_dims=2, split_output=True, auto_pad=True), + }, + # Gather output at proj_out + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 2, + in_channels: int = 64, + out_channels: int | None = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, + axes_dims_rope: tuple[int, int, int] = (16, 56, 56), + zero_cond_t: bool = False, + use_additional_t_cond: bool = False, + use_layer3d_rope: bool = False, + ): + super().__init__() + self.parallel_config = od_config.parallel_config + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.guidance_embeds = guidance_embeds + + if not use_layer3d_rope: + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + else: + self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond + ) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + zero_cond_t=zero_cond_t, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + self.zero_cond_t = zero_cond_t + + # ImageRopePrepare module for _sp_plan to shard hidden_states and vid_freqs together + # This ensures RoPE dimensions align with hidden_states after sharding + self.image_rope_prepare = ImageRopePrepare(self.img_in, self.pos_embed) + + # ModulateIndexPrepare module for _sp_plan to shard modulate_index + # This ensures modulate_index dimensions align with hidden_states after sharding + # Only active when zero_cond_t=True (image 
editing models)
+ self.modulate_index_prepare = ModulateIndexPrepare(zero_cond_t=zero_cond_t)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_hidden_states_mask: torch.Tensor | None = None,
+ timestep: torch.LongTensor | None = None,
+ img_shapes: list[tuple[int, int, int]] | None = None,
+ txt_seq_lens: list[int] | None = None,
+ guidance: torch.Tensor | None = None, # TODO: this should probably be removed
+ attention_kwargs: dict[str, Any] | None = None,
+ additional_t_cond=None,
+ return_dict: bool = True,
+ ) -> torch.Tensor | Transformer2DModelOutput:
+ """
+ The [`QwenImageTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+ Mask of the input conditions.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # if attention_kwargs is not None:
+ # attention_kwargs = attention_kwargs.copy()
+ # lora_scale = attention_kwargs.pop("scale", 1.0)
+ # else:
+ # lora_scale = 1.0
+
+ # Set split_text_embed_in_sp = False for dual-stream attention
+ # QwenImage uses *dual-stream* (text + image) and runs a *joint attention*.
+ # Text embeddings must be replicated across SP ranks for correctness.
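+ # (If the text stream were sharded as well, each SP rank's joint [text; image] attention would only
+ # see part of the prompt, so the flag below keeps the full text replicated on every rank.)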
+ if self.parallel_config.sequence_parallel_size > 1: + get_forward_context().split_text_embed_in_sp = False + + # Prepare hidden_states and RoPE via ImageRopePrepare module + # _sp_plan will shard hidden_states and vid_freqs together via split_output=True + # txt_freqs is kept replicated for dual-stream attention + hidden_states, vid_freqs, txt_freqs = self.image_rope_prepare(hidden_states, img_shapes, txt_seq_lens) + image_rotary_emb = (vid_freqs, txt_freqs) + + # Ensure timestep tensor is on the same device and dtype as hidden_states + timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype) + + # Prepare timestep and modulate_index via ModulateIndexPrepare module + # _sp_plan will shard modulate_index via split_output=True (when zero_cond_t=True) + # This ensures modulate_index sequence dimension matches sharded hidden_states + timestep, modulate_index = self.modulate_index_prepare(timestep, img_shapes) + + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states, additional_t_cond) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) + ) + + # Check for SP auto_pad: create attention mask dynamically if padding was applied + # In Ulysses mode, attention is computed on the FULL sequence (after All-to-All) + hidden_states_mask = None # default + if self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1: + ctx = get_forward_context() + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + # Create mask for the full (padded) sequence + # valid positions = True, padding positions = False + batch_size = hidden_states.shape[0] + padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + hidden_states_mask = torch.ones( + batch_size, + padded_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, ctx.sp_original_seq_len :] = False + + # if mask is all true, set it to None + if hidden_states_mask is not None and hidden_states_mask.all(): + hidden_states_mask = None + if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all(): + encoder_hidden_states_mask = None + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + modulate_index=modulate_index, + hidden_states_mask=hidden_states_mask, + ) + + if self.zero_cond_t: + temb = temb.chunk(2, dim=0)[0] + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + # Note: SP gather is handled automatically by _sp_plan's SequenceParallelGatherHook + # on proj_out output. No manual all_gather needed here. 
+ + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self-attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross-attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + original_name = name + lookup_name = name + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in original_name: + continue + lookup_name = original_name.replace(weight_name, param_name) + param = params_dict[lookup_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if lookup_name not in params_dict and ".to_out.0." in lookup_name: + lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + param = params_dict[lookup_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(original_name) + loaded_params.add(lookup_name) + return loaded_params diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8df78ebf00ed4978ba56a61398d0acdd7608e0 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, +) + +__all__ = [ + "FlowUniPCMultistepScheduler", +] diff --git a/vllm_omni/diffusion/models/schedulers/base.py b/vllm_omni/diffusion/models/schedulers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9d87d7f55a33a097e99bd37eeb7bd59e68c796 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/base.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/hao-ai-lab/FastVideo +# Originally from https://github.com/huggingface/diffusers +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +"""Base scheduler class for diffusion models.""" + +from abc import ABC, abstractmethod + +import torch + + +class BaseScheduler(ABC): + """ + Abstract base class for schedulers. 
+ + Subclasses must define: + - timesteps: torch.Tensor + - order: int + - num_train_timesteps: int + """ + + timesteps: torch.Tensor + order: int + num_train_timesteps: int + + def __init__(self): + required_attrs = ["timesteps", "order", "num_train_timesteps"] + for attr in required_attrs: + if not hasattr(self, attr): + raise AttributeError( + f"Subclass {self.__class__.__name__} must define `{attr}` before calling super().__init__()" + ) + + @abstractmethod + def set_shift(self, shift: float) -> None: + """Set the shift parameter for the scheduler.""" + raise NotImplementedError + + @abstractmethod + def set_timesteps(self, *args, **kwargs) -> None: + """Set the timesteps for the scheduler.""" + raise NotImplementedError + + @abstractmethod + def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: + """Scale the model input.""" + raise NotImplementedError diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py new file mode 100644 index 0000000000000000000000000000000000000000..3efe564bc618dbc935ff69db5b86f8fc1a538e5a --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py @@ -0,0 +1,741 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/hao-ai-lab/FastVideo +# Originally from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +""" +FlowUniPCMultistepScheduler - A training-free framework for fast sampling of flow-matching diffusion models. + +This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching, +providing faster convergence than simple Euler methods while maintaining quality. +""" + +from __future__ import annotations + +import math +from typing import Any + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from diffusers.utils import deprecate + +from vllm_omni.diffusion.models.schedulers.base import BaseScheduler + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + `FlowUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of + flow-matching diffusion models. + + This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching, + which can achieve the same quality as Euler methods in fewer steps (typically 20-30 steps vs 40-50). + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler. + shift (`float`, defaults to 1.0): + The shift parameter for the noise schedule. For Wan2.2: use 5.0 for 720p, 12.0 for 480p. 
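+ (Illustrative: with shift=5.0 a raw sigma of 0.5 is remapped to 5 * 0.5 / (1 + 4 * 0.5) ≈ 0.83,
+ so larger shifts concentrate the schedule at higher noise levels.)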
+ use_dynamic_shifting (`bool`, defaults to False): + Whether to use dynamic shifting based on image resolution. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. Use `bh1` for unconditional sampling when steps < 10, `bh2` otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Stabilizes sampling for steps < 15. + disable_corrector (`list`, default `[]`): + Steps to disable the corrector to mitigate misalignment with large guidance scales. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule. Either `"zero"` or `"sigma_min"`. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting: bool = False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: tuple = (), + solver_p: SchedulerMixin | None = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", + **kwargs, + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + self.num_inference_steps: int | None = None + + # Initialize sigma schedule + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # Apply timestep shifting based on shift parameter + assert shift is not None + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + self.num_train_timesteps = num_train_timesteps + + # State for multistep solver + self.model_outputs: list[torch.Tensor | None] = [None] * solver_order + self.timestep_list: list[Any | None] = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = list(disable_corrector) + self.solver_p = solver_p + self.last_sample: torch.Tensor | None = None + self._step_index: int | None = None + self._begin_index: int | None = None + self.this_order: int = 1 + + # Move sigmas to CPU to reduce GPU/CPU communication + self.sigmas = self.sigmas.to("cpu") + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + BaseScheduler.__init__(self) + + @property + def step_index(self) -> int | None: + """The index counter for current timestep. Increases by 1 after each scheduler step.""" + return self._step_index + + @property + def begin_index(self) -> int | None: + """The index for the first timestep. 
Should be set from pipeline with `set_begin_index` method.""" + return self._begin_index + + def set_shift(self, shift: float) -> None: + """Set the shift parameter for the scheduler.""" + self.config.shift = shift + + def set_begin_index(self, begin_index: int = 0) -> None: + """ + Sets the begin index for the scheduler. Run from pipeline before inference. + + Args: + begin_index (`int`): The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + sigmas: list[float] | None = None, + mu: float | None = None, + shift: float | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain (run before inference). + + Args: + num_inference_steps (`int`): + Total number of timesteps. + device (`str` or `torch.device`, *optional*): + The device to move timesteps to. + sigmas (`list[float]`, *optional*): + Custom sigma schedule. + mu (`float`, *optional*): + Parameter for dynamic shifting. + shift (`float`, *optional*): + Override shift parameter. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("Must pass a value for `mu` when `use_dynamic_shifting` is True") + + if sigmas is None: + assert num_inference_steps is not None + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + assert mu is not None + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + assert isinstance(sigmas, np.ndarray) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError(f"`final_sigmas_type` must be 'zero' or 'sigma_min', got {self.config.final_sigmas_type}") + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + # Reset state + self.model_outputs = [None] * self.config.solver_order + self.timestep_list = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + Dynamic thresholding to prevent pixel saturation. 
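+ In short, per sample: s = clamp(quantile(|x0|, dynamic_thresholding_ratio), min=1, max=sample_max_value),
+ then x0 is clamped to [-s, s] and rescaled by 1 / s.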
+ + From "Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding" + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + abs_sample = sample.abs() + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor: + """Convert sigma to timestep.""" + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert sigma to alpha and sigma_t for flow matching.""" + return 1 - sigma, sigma + + def time_shift(self, mu: float, sigma: float, t: np.ndarray) -> np.ndarray: + """Apply time shift transformation.""" + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the format needed by the UniPC algorithm. + + Args: + model_output (`torch.Tensor`): Direct output from the diffusion model. + sample (`torch.Tensor`): Current sample in the diffusion process. + + Returns: + `torch.Tensor`: Converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion " + "is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index].to(sample.device) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = sigma.to(sample.device) + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be `flow_prediction` " + "for the FlowUniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = sigma.to(sample.device) + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be `flow_prediction` " + "for the FlowUniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = sigma.to(sample.device) + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor | None = None, + order: int | None = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version) predictor. + + Args: + model_output (`torch.Tensor`): Direct output from the diffusion model. + sample (`torch.Tensor`): Current sample. 
+ order (`int`): The order of UniP at this timestep. + + Returns: + `torch.Tensor`: The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError("missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect.", + ) + + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + device = sample.device + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1].to(device), + self.sigmas[self.step_index].to(device), + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + + rks = [] + D1s: list[Any] | None = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device)) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + assert mi is not None + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if D1s is not None and len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + assert isinstance(R, torch.Tensor) + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor | None = None, + this_sample: torch.Tensor | None = None, + order: int | None = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version) corrector. + + Args: + this_model_output (`torch.Tensor`): Model outputs at `x_t`. + last_sample (`torch.Tensor`): Sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): Sample after the last predictor `x_{t}`. + order (`int`): The order of UniC-p. Effective accuracy is `order + 1`. 
+ + Returns: + `torch.Tensor`: The corrected sample tensor. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect.", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + device = this_sample.device + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index].to(device), + self.sigmas[self.step_index - 1].to(device), + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + + rks = [] + D1s: list[Any] | None = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device)) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + assert mi is not None + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if D1s is not None and len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep: torch.Tensor, schedule_timesteps: torch.Tensor | None = None) -> int: + """Get the index for a given timestep.""" + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + pos = 1 if len(indices) > 1 else 0 + step_index: int = indices[pos].item() + + return step_index + + def _init_step_index(self, timestep: torch.Tensor) -> None: + """Initialize the step_index counter for the 
scheduler.""" + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> SchedulerOutput | tuple: + """ + Predict the sample from the previous timestep by reversing the SDE using multistep UniPC. + + Args: + model_output (`torch.Tensor`): Direct output from the diffusion model. + timestep (`int`): Current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): Current sample created by the diffusion process. + return_dict (`bool`): Whether to return a SchedulerOutput or tuple. + + Returns: + `SchedulerOutput` or `tuple`: The sample tensor at the previous timestep. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + # Update model output history + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + # Determine order for this step + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + assert self._step_index is not None + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input. + + Args: + sample (`torch.Tensor`): The input sample. + + Returns: + `torch.Tensor`: A scaled input sample (unchanged for this scheduler). + """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples (`torch.Tensor`): Original samples. + noise (`torch.Tensor`): Noise to add. + timesteps (`torch.IntTensor`): Timesteps for noise addition. + + Returns: + `torch.Tensor`: Noisy samples. 
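+
+ Note: since alpha_t = 1 - sigma_t for this flow-matching scheduler, this computes the linear
+ interpolation `noisy = (1 - sigma_t) * original + sigma_t * noise` (sigma_t = 1 yields pure noise).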
+ """ + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + step_indices = [self.step_index] * timesteps.shape[0] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self) -> int: + return self.config.num_train_timesteps diff --git a/vllm_omni/diffusion/models/sd3/__init__.py b/vllm_omni/diffusion/models/sd3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7efafd9414d84dccae0c40f9447eef92863d31b --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/__init__.py @@ -0,0 +1,15 @@ +"""Stable diffusion3 model components.""" + +from vllm_omni.diffusion.models.sd3.pipeline_sd3 import ( + StableDiffusion3Pipeline, + get_sd3_image_post_process_func, +) +from vllm_omni.diffusion.models.sd3.sd3_transformer import ( + SD3Transformer2DModel, +) + +__all__ = [ + "StableDiffusion3Pipeline", + "SD3Transformer2DModel", + "get_sd3_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..3668c132f53a82adbd787bfaa7741d8cbb994a99 --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -0,0 +1,699 @@ +import inspect +import json +import logging +import os +from collections.abc import Iterable + +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.sd3.sd3_transformer import ( + SD3Transformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_sd3_image_post_process_func( + od_config: OmniDiffusionConfig, +): + if od_config.output_type == "latent": + return lambda x: x + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + 
vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Pipeline(nn.Module, CFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_2 = CLIPTokenizer.from_pretrained( + model, subfolder="tokenizer_2", local_files_only=local_files_only + ) + self.tokenizer_3 = T5Tokenizer.from_pretrained( + model, subfolder="tokenizer_3", local_files_only=local_files_only + ) + self.text_encoder = CLIPTextModelWithProjection.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + model, subfolder="text_encoder_2", local_files_only=local_files_only + ) + self.text_encoder_3 = T5EncoderModel.from_pretrained( + model, + subfolder="text_encoder_3", + local_files_only=local_files_only, + ) + self.transformer = SD3Transformer2DModel(od_config=od_config) + + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + self.patch_size = 2 + self.output_type = self.od_config.output_type + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by " + f"{self.vae_scale_factor * self.patch_size} but are " + f"{height} and {width}. You can use height " + f"{height - height % (self.vae_scale_factor * self.patch_size)} " + f"and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+        )
+        elif prompt_3 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+            raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+ ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str] = "", + num_images_per_prompt: int = 1, + dtype: torch.dtype | None = None, + clip_model_index: int = 0, + ): + dtype = dtype or self.text_encoder.dtype + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(self.device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = "", + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + dtype: torch.dtype | None = None, + ): + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size, + max_sequence_length, + self.transformer.joint_attention_dim, + ), + device=self.device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(self.device) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(self.device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + num_images_per_prompt: int = 1, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + + pooled_prompt_embeds = None + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + return prompt_embeds, pooled_prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None): + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + **scheduler_kwargs, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_pooled_prompt_embeds: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. + + Args: + latents: Noise latents to denoise + timesteps: Diffusion timesteps + prompt_embeds: Positive prompt embeddings + pooled_prompt_embeds: Pooled positive prompt embeddings + negative_prompt_embeds: Negative prompt embeddings + negative_pooled_prompt_embeds: Pooled negative prompt embeddings + do_true_cfg: Whether to apply CFG + guidance_scale: CFG scale factor + cfg_normalize: Whether to normalize CFG output (default: False) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + + for _, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "return_dict": False, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "pooled_projections": negative_pooled_prompt_embeds, + "return_dict": False, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + guidance_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] = "", + prompt_2: str | list[str] = "", + prompt_3: str | list[str] = "", + negative_prompt: str | list[str] = "", + negative_prompt_2: str | list[str] = "", + negative_prompt_3: str | list[str] = "", + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: 
torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ) -> DiffusionOutput: + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + negative_prompt = [ + "" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts + ] or negative_prompt + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + # 1. check inputs + # 2. encode prompts + # 3. prepare latents and timesteps + # 4. diffusion process + # 5. decode latents + # 6. post-process outputs + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = req.sampling_params.guidance_scale + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds, pooled_prompt_embeds = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + do_cfg = self.guidance_scale > 1 + if do_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_3=negative_prompt_3, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + self._num_timesteps = len(timesteps) + + # Denoising loop using diffuse method + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds if do_cfg else None, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds if do_cfg else None, + do_true_cfg=do_cfg, + guidance_scale=self.guidance_scale, + cfg_normalize=False, + ) + + self._current_timestep = None + if self.output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + latents = (latents / 
self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/sd3/sd3_transformer.py b/vllm_omni/diffusion/models/sd3/sd3_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e60bcbe5a14f8cb8d32dc942ac696f427b101664 --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py @@ -0,0 +1,479 @@ +from collections.abc import Iterable + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward + +# TODO replace this with vLLM implementation +from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig + +logger = init_logger(__name__) + + +class SD3PatchEmbed(nn.Module): + """ + 2D Image to Patch Embedding with support for SD3. + + Args: + patch_size (`int`, defaults to `16`): The size of the patches. + in_channels (`int`, defaults to `3`): The number of input channels. + embed_dim (`int`, defaults to `768`): The output dimension of the embedding. + """ + + def __init__( + self, + patch_size=16, + in_channels=3, + embed_dim=768, + ): + super().__init__() + + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True + ) + + def forward(self, latent): + x = self.proj(latent) # [B, embed_dim, patch_size, patch_size] + x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim] + return x + + +class SD3CrossAttention(nn.Module): + def __init__( + self, + dim: int, # query_dim + num_heads: int, + head_dim: int, + added_kv_proj_dim: int = 0, + out_bias: bool = True, + qk_norm=True, # rmsnorm + eps=1e-6, + pre_only=False, + context_pre_only: bool = False, + parallel_attention=False, + out_dim: int = 0, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + disable_tp=True, + ) + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads + self.inner_kv_dim = self.inner_dim + if added_kv_proj_dim is not None: + self.add_kv_proj = QKVParallelLinear( + added_kv_proj_dim, + head_size=self.inner_kv_dim // self.num_heads, + total_num_heads=self.num_heads, + disable_tp=True, + ) + + if not context_pre_only: + self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + else: + self.to_add_out = None + + if not pre_only: + self.to_out = 
nn.ModuleList([]) + self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)) + else: + self.to_out = None + + self.norm_added_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_added_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + ): + # Compute QKV for image stream (sample projections) + qkv, _ = self.to_qkv(hidden_states) + img_query, img_key, img_value = qkv.chunk(3, dim=-1) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (self.num_heads, -1)) + img_key = img_key.unflatten(-1, (self.num_heads, -1)) + img_value = img_value.unflatten(-1, (self.num_heads, -1)) + + # Apply QK normalization + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + + if encoder_hidden_states is not None: + # Compute QKV for text stream (context projections) + qkv, _ = self.add_kv_proj(encoder_hidden_states) + txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1) + + txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) + txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) + txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + # Concatenate for joint attention + # Order: [text, image] + query = torch.cat([txt_query, img_query], dim=1) + key = torch.cat([txt_key, img_key], dim=1) + value = torch.cat([txt_value, img_value], dim=1) + else: + query = img_query + key = img_key + value = img_value + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split attention outputs back + context_seqlen = encoder_hidden_states.shape[1] + hidden_states, encoder_hidden_states = ( + hidden_states[:, context_seqlen:, :], # Image part + hidden_states[:, :context_seqlen, :], # Text part + ) + if self.to_add_out is not None: + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + # Apply output projections + if self.to_out is not None: + hidden_states, _ = self.to_out[0](hidden_states) + + if encoder_hidden_states is None: + return hidden_states + else: + return hidden_states, encoder_hidden_states + + +class SD3TransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://huggingface.co/papers/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. 
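+        qk_norm (`str`, *optional*):
+            If set to `"rms_norm"`, RMSNorm is applied to the query and key projections inside the attention layers.
+        use_dual_attention (`bool`, defaults to `False`):
+            Whether to add the second, image-only attention path used by the SD3.5-style dual-attention blocks.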
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: str | None = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continuous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently " + f"only support `ada_norm_continuous`, `ada_norm_zero`" + ) + + self.attn = SD3CrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + added_kv_proj_dim=dim, + context_pre_only=context_pre_only, + out_dim=dim, + qk_norm=True if qk_norm == "rms_norm" else False, + eps=1e-6, + ) + + if use_dual_attention: + self.attn2 = SD3CrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + out_dim=dim, + qk_norm=True if qk_norm == "rms_norm" else False, + eps=1e-6, + ) + else: + self.attn2 = None + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + self.ff_context = None + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +class SD3Transformer2DModel(nn.Module): + """ + The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). + """ + + _repeated_blocks = ["SD3TransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + od_config: OmniDiffusionConfig, + ): + super().__init__() + model_config = od_config.tf_model_config + self.num_layers = model_config.num_layers + self.parallel_config = od_config.parallel_config + self.sample_size = model_config.sample_size + self.in_channels = model_config.in_channels + self.out_channels = model_config.out_channels + self.num_attention_heads = model_config.num_attention_heads + self.attention_head_dim = model_config.attention_head_dim + self.inner_dim = model_config.num_attention_heads * model_config.attention_head_dim + self.caption_projection_dim = model_config.caption_projection_dim + self.pooled_projection_dim = model_config.pooled_projection_dim + self.joint_attention_dim = model_config.joint_attention_dim + self.patch_size = model_config.patch_size + self.dual_attention_layers = ( + model_config.dual_attention_layers if hasattr(model_config, "dual_attention_layers") else () + ) + self.qk_norm = model_config.qk_norm if hasattr(model_config, "qk_norm") else "" + self.pos_embed_max_size = model_config.pos_embed_max_size + + self.pos_embed = PatchEmbed( + height=self.sample_size, + width=self.sample_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=self.pos_embed_max_size, + ) + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=self.pooled_projection_dim + ) + self.context_embedder = nn.Linear(self.joint_attention_dim, self.caption_projection_dim) + + self.transformer_blocks = nn.ModuleList( + [ + SD3TransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + context_pre_only=i == self.num_layers - 1, + qk_norm=self.qk_norm, + use_dual_attention=True if i in self.dual_attention_layers else False, + ) + for i in range(self.num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, self.patch_size * self.patch_size * self.out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_projections: torch.Tensor, + timestep: torch.LongTensor, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. 
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self-attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross-attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".pos_embed"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/stable_audio/__init__.py b/vllm_omni/diffusion/models/stable_audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..baa986a0ffa01b16ab947ee684cbd0af967cb4d2 --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Stable Audio Open model support for vLLM-Omni.""" + +from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import ( + StableAudioPipeline, + get_stable_audio_post_process_func, +) +from 
vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import ( + StableAudioDiTModel, +) + +__all__ = [ + "StableAudioDiTModel", + "StableAudioPipeline", + "get_stable_audio_post_process_func", +] diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..c48d68efd6486f17bab193dfb7264e58d64d3045 --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -0,0 +1,575 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Stable Audio Open Pipeline for vLLM-Omni. + +This module provides text-to-audio generation using the Stable Audio Open model +from Stability AI, integrated with the vLLM-Omni diffusion framework. +""" + +from __future__ import annotations + +import os +from collections.abc import Iterable + +import torch +from diffusers import AutoencoderOobleck +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.schedulers import CosineDPMSolverMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import T5EncoderModel, T5TokenizerFast +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportAudioOutput +from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +def get_stable_audio_post_process_func( + od_config: OmniDiffusionConfig, +): + """ + Create post-processing function for Stable Audio output. + + Converts raw audio tensor to numpy array for saving. + """ + + def post_process_func( + audio: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return audio + if output_type == "pt": + return audio + # Convert to numpy + audio_np = audio.cpu().float().numpy() + return audio_np + + return post_process_func + + +class StableAudioPipeline(nn.Module, SupportAudioOutput): + """ + Pipeline for text-to-audio generation using Stable Audio Open. + + This pipeline generates audio from text prompts using the Stable Audio Open model + from Stability AI, integrated with vLLM-Omni's diffusion framework. 
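+
+    The tokenizer, text encoder, VAE, projection model and scheduler are loaded from the diffusers
+    checkpoint layout, while the denoising transformer is the vLLM-Omni `StableAudioDiTModel`, whose
+    weights are loaded separately through `load_weights`.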
+ + Args: + od_config: OmniDiffusion configuration object + prefix: Weight prefix for loading (default: "") + """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.float16) + + model = od_config.model + local_files_only = os.path.exists(model) + + # Set up weights sources for transformer + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + # Load tokenizer + self.tokenizer = T5TokenizerFast.from_pretrained( + model, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + + # Load text encoder + self.text_encoder = T5EncoderModel.from_pretrained( + model, + subfolder="text_encoder", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + # Load VAE (AutoencoderOobleck for audio) + self.vae = AutoencoderOobleck.from_pretrained( + model, + subfolder="vae", + torch_dtype=torch.float32, + local_files_only=local_files_only, + ).to(self.device) + + # Load projection model (using diffusers implementation) + self.projection_model = StableAudioProjectionModel.from_pretrained( + model, + subfolder="projection_model", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + # Initialize our custom transformer (weights loaded via load_weights) + self.transformer = StableAudioDiTModel(od_config=od_config) + + # Load scheduler + self.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained( + model, + subfolder="scheduler", + local_files_only=local_files_only, + ) + + # Compute rotary embedding dimension + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Cache backend (set by worker if needed) + self._cache_backend = None + + # Properties + self._guidance_scale = None + self._num_timesteps = None + self._current_timestep = None + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + def check_inputs( + self, + prompt: str | list[str] | None, + audio_start_in_s: float, + audio_end_in_s: float, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + """Validate input parameters.""" + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}` must be higher than `audio_start_in_s={audio_start_in_s}`" + ) + + min_val = self.projection_model.config.min_value + max_val = self.projection_model.config.max_value + + if audio_start_in_s < min_val or audio_start_in_s > max_val: + raise ValueError(f"`audio_start_in_s` must be between {min_val} and {max_val}, got {audio_start_in_s}") + + if audio_end_in_s < min_val or audio_end_in_s > max_val: + raise ValueError(f"`audio_end_in_s` must be between {min_val} and {max_val}, got {audio_end_in_s}") + + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. 
Cannot leave both undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`. Please provide only one.") + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device, + do_classifier_free_guidance: bool, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Encode text prompt to embeddings.""" + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # Tokenize + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # Encode + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + )[0] + + # Handle negative prompt for CFG + if do_classifier_free_guidance and negative_prompt is not None: + if isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt` has batch size {len(negative_prompt)}, but `prompt` " + f"has batch size {batch_size}. Please make sure they match." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + )[0] + + if negative_attention_mask is not None: + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), + negative_prompt_embeds, + 0.0, + ) + + # Concatenate for CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + # Project embeddings + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s: float, + audio_end_in_s: float, + device: torch.device, + do_classifier_free_guidance: bool, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Encode audio duration to conditioning tensors.""" + audio_start_in_s = [audio_start_in_s] if isinstance(audio_start_in_s, (int, float)) else audio_start_in_s + 
audio_end_in_s = [audio_end_in_s] if isinstance(audio_end_in_s, (int, float)) else audio_end_in_s + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + audio_start_in_s = torch.tensor([float(x) for x in audio_start_in_s]).to(device) + audio_end_in_s = torch.tensor([float(x) for x in audio_end_in_s]).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + def prepare_latents( + self, + batch_size: int, + num_channels_vae: int, + sample_size: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare initial latent noise.""" + shape = (batch_size, num_channels_vae, sample_size) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Scale by scheduler's noise sigma + latents = latents * self.scheduler.init_noise_sigma + return latents + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + audio_end_in_s: float | None = None, + audio_start_in_s: float = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + num_waveforms_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str = "np", + ) -> DiffusionOutput: + """ + Generate audio from text prompt. + + Args: + req: OmniDiffusionRequest containing generation parameters + prompt: Text prompt for audio generation + negative_prompt: Negative prompt for CFG + audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) + audio_start_in_s: Audio start time in seconds + num_inference_steps: Number of denoising steps + guidance_scale: CFG scale + num_waveforms_per_prompt: Number of audio outputs per prompt + generator: Random generator for reproducibility + latents: Pre-generated latents + prompt_embeds: Pre-computed prompt embeddings + negative_prompt_embeds: Pre-computed negative prompt embeddings + output_type: Output format ("np", "pt", or "latent") + + Returns: + DiffusionOutput containing generated audio + """ + # Extract from request + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. 
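+        # req.prompts is expected to contain either plain strings or dicts such as
+        # {"prompt": "...", "negative_prompt": None}; both forms are normalized below.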
+ prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + + # Get audio duration from request extra params or defaults + audio_start_in_s = req.sampling_params.extra_args.get("audio_start_in_s", audio_start_in_s) + audio_end_in_s = req.sampling_params.extra_args.get("audio_end_in_s", audio_end_in_s) + + # Calculate audio length + downsample_ratio = self.vae.hop_length + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"Requested audio length ({audio_end_in_s - audio_start_in_s}s) exceeds " + f"maximum ({max_audio_length_in_s}s)" + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # Validate inputs + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # Determine batch size + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + do_classifier_free_guidance = guidance_scale > 1.0 + self._guidance_scale = guidance_scale + + # Encode prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create combined embeddings + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], + dim=1, + ) + audio_duration_embeds = torch.cat( + [seconds_start_hidden_states, seconds_end_hidden_states], + dim=2, + ) + + # Handle CFG without negative prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like(text_audio_duration_embeds) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], + dim=0, + ) + audio_duration_embeds = torch.cat( + [audio_duration_embeds, audio_duration_embeds], + dim=0, + ) + + # Duplicate for multiple waveforms per prompt + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + 
text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # Prepare latents + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + ) + + # Prepare rotary embeddings and move to device + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + # Move rotary embeddings to device (returns tuple of cos, sin) + rotary_embedding = ( + rotary_embedding[0].to(device=device, dtype=latents.dtype), + rotary_embedding[1].to(device=device, dtype=latents.dtype), + ) + + # Denoising loop + for t in timesteps: + self._current_timestep = t + + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict noise + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # Perform CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + self._current_timestep = None + + # Decode + if output_type == "latent": + audio = latents + else: + # Convert latents to VAE dtype (VAE may use float32) + latents_for_vae = latents.to(dtype=self.vae.dtype) + audio = self.vae.decode(latents_for_vae).sample + + # Trim to requested length + audio = audio[:, :, waveform_start:waveform_end] + + return DiffusionOutput(output=audio) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..22d56ac1fd1db77ad0635d01d0d67f0507384131 --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -0,0 +1,602 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Stable Audio DiT Model for vLLM-Omni. 
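+
+Provides rotary position embedding helpers, a Gaussian Fourier time embedding, self- and
+cross-attention layers, the DiT block, and the `StableAudioDiTModel` built on vLLM's linear and
+attention layers.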
+""" + +import math +from collections.abc import Iterable + +import torch +import torch.nn as nn +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig + +logger = init_logger(__name__) + + +def apply_rotary_emb_stable_audio( + hidden_states: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors for Stable Audio. + + Args: + hidden_states: Input tensor of shape [B, S, H, D] where D is head_dim + freqs_cis: Tuple of (cos, sin) frequency tensors of shape [S, rotary_dim] + where rotary_dim = head_dim // 2 + + Returns: + Tensor with rotary embeddings applied to first rotary_dim dimensions only. + The remaining dimensions are left unchanged (pass-through). + """ + cos, sin = freqs_cis # [S, rotary_dim] + rotary_dim = cos.shape[-1] + + # Rotate only the first rotary_dim entries; leave the rest unchanged + x_rot = hidden_states[..., :rotary_dim] + x_pass = hidden_states[..., rotary_dim:] + + cos = cos[None, :, None, :] # [1, S, 1, rotary_dim] + sin = sin[None, :, None, :] # [1, S, 1, rotary_dim] + + # [B, S, H, rotary_dim] -> [B, S, H, 2, rotary_dim//2] -> two halves + x_real, x_imag = x_rot.reshape(*x_rot.shape[:-1], 2, rotary_dim // 2).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + x_rot = (x_rot.float() * cos + x_rotated.float() * sin).to(hidden_states.dtype) + return torch.cat([x_rot, x_pass], dim=-1) + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels. + + Matches diffusers StableAudioGaussianFourierProjection with: + - flip_sin_to_cos=True (output is [cos, sin] not [sin, cos]) + - log=False (no log transformation of input) + """ + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x shape: [batch] or [batch, 1] + # Output: [batch, embedding_size * 2] + x_proj = 2 * math.pi * x[:, None] @ self.weight[None, :] + # flip_sin_to_cos=True means cos comes first + return torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + + +class StableAudioSelfAttention(nn.Module): + """ + Optimized self-attention for Stable Audio using vLLM layers. + + Self-attention uses full attention (all heads for Q, K, V). + GQA is only used for cross-attention. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + + # All projections use inner_dim for output + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_k = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_v = ReplicatedLinear(dim, self.inner_dim, bias=False) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=False), + nn.Dropout(dropout), + ] + ) + + # Full attention (no GQA for self-attention) + self.attn = Attention( + num_heads=num_attention_heads, + head_size=attention_head_dim, + softmax_scale=1.0 / (attention_head_dim**0.5), + causal=False, + num_kv_heads=num_attention_heads, # Same as query heads + ) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + # Projections - all output inner_dim + query, _ = self.to_q(hidden_states) + key, _ = self.to_k(hidden_states) + value, _ = self.to_v(hidden_states) + + # Reshape for multi-head attention (all use full heads) + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Apply rotary embeddings + if rotary_emb is not None: + query = apply_rotary_emb_stable_audio(query, rotary_emb) + key = apply_rotary_emb_stable_audio(key, rotary_emb) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.view(batch_size, seq_len, self.inner_dim) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class StableAudioCrossAttention(nn.Module): + """ + Optimized cross-attention for Stable Audio using vLLM layers. + + For cross-attention: + - Q projection: outputs inner_dim (full heads) + - K/V projections: outputs kv_dim (reduced heads for GQA) + + GQA is handled by manually expanding K/V heads to match Q heads + since the SDPA backend doesn't handle this automatically. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_attention_heads + self.num_kv_heads = num_key_value_attention_heads + self.head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + self.kv_dim = num_key_value_attention_heads * attention_head_dim + + # Number of times to repeat KV heads + self.num_kv_groups = num_attention_heads // num_key_value_attention_heads + + # Q outputs inner_dim, K/V output kv_dim (GQA) + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_k = ReplicatedLinear(cross_attention_dim, self.kv_dim, bias=False) + self.to_v = ReplicatedLinear(cross_attention_dim, self.kv_dim, bias=False) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=False), + nn.Dropout(dropout), + ] + ) + + # Use full heads for attention (KV will be expanded) + self.attn = Attention( + num_heads=num_attention_heads, + head_size=attention_head_dim, + softmax_scale=1.0 / (attention_head_dim**0.5), + causal=False, + num_kv_heads=num_attention_heads, # After expansion + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + encoder_seq_len = encoder_hidden_states.shape[1] + + # Projections + query, _ = self.to_q(hidden_states) + key, _ = self.to_k(encoder_hidden_states) + value, _ = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, encoder_seq_len, self.num_kv_heads, self.head_dim) + value = value.view(batch_size, encoder_seq_len, self.num_kv_heads, self.head_dim) + + # Expand K/V heads to match Q heads for GQA + # [B, S, kv_heads, D] -> [B, S, kv_heads, 1, D] -> [B, S, kv_heads, groups, D] -> [B, S, num_heads, D] + key = key.unsqueeze(3).expand(-1, -1, -1, self.num_kv_groups, -1) + key = key.reshape(batch_size, encoder_seq_len, self.num_heads, self.head_dim) + value = value.unsqueeze(3).expand(-1, -1, -1, self.num_kv_groups, -1) + value = value.reshape(batch_size, encoder_seq_len, self.num_heads, self.head_dim) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.view(batch_size, seq_len, self.inner_dim) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class SwiGLU(nn.Module): + """SwiGLU activation - matches diffusers structure.""" + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class StableAudioFeedForward(nn.Module): + """ + Feed-forward network with SwiGLU activation for Stable Audio. + Matches diffusers FeedForward structure with activation_fn="swiglu". 
+ """ + + def __init__(self, dim: int, inner_dim: int, bias: bool = True): + super().__init__() + # Structure matches diffusers FeedForward: + # net.0 = SwiGLU (proj.weight, proj.bias) + # net.1 = Dropout + # net.2 = Linear (weight, bias) + self.net = nn.Sequential( + SwiGLU(dim, inner_dim, bias=bias), + nn.Dropout(0.0), + nn.Linear(inner_dim, dim, bias=bias), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.net(hidden_states) + + +class StableAudioDiTBlock(nn.Module): + """ + Stable Audio DiT block with self-attention, cross-attention, and FFN. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + ff_mult: int = 4, + ): + super().__init__() + + # Self-attention with layer norm + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True) + self.attn1 = StableAudioSelfAttention( + dim=dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + ) + + # Cross-attention with layer norm + self.norm2 = nn.LayerNorm(dim, elementwise_affine=True) + self.attn2 = StableAudioCrossAttention( + dim=dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + + # Feed-forward with SwiGLU activation + # inner_dim = dim * ff_mult (e.g., 1536 * 4 = 6144) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=True) + self.ff = StableAudioFeedForward(dim, inner_dim=dim * ff_mult) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + rotary_embedding: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # Self-attention with skip connection + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn1(hidden_states, rotary_emb=rotary_embedding, attention_mask=attention_mask) + hidden_states = residual + hidden_states + + # Cross-attention with skip connection + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2( + hidden_states, + encoder_hidden_states, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + hidden_states = residual + hidden_states + + # Feed-forward with skip connection + residual = hidden_states + hidden_states = self.norm3(hidden_states) + hidden_states = self.ff(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class StableAudioDiTModel(nn.Module): + """ + Optimized Stable Audio DiT model using vLLM layers. + + This is an optimized version of the diffusers StableAudioDiTModel that uses + vLLM's efficient linear layers and attention implementations. 
+ + Architecture: + - Input: [B, in_channels, L] (e.g., [B, 64, L]) + - preprocess_conv: residual conv layer (keeps 64 channels) + - proj_in: projects 64 -> 1536 (inner_dim) + - Global+time embeddings prepended to sequence + - Transformer blocks work on 1536-dim + - proj_out: projects 1536 -> 64 (out_channels) + - postprocess_conv: residual conv layer (keeps 64 channels) + - Output: [B, out_channels, L] + """ + + def __init__( + self, + od_config: OmniDiffusionConfig | None = None, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + + # inner_dim is the transformer hidden dimension + self.inner_dim = num_attention_heads * attention_head_dim + + # Store config for compatibility + self.config = type( + "Config", + (), + { + "sample_size": sample_size, + "in_channels": in_channels, + "out_channels": out_channels, + "num_layers": num_layers, + "attention_head_dim": attention_head_dim, + "num_attention_heads": num_attention_heads, + "num_key_value_attention_heads": num_key_value_attention_heads, + "cross_attention_dim": cross_attention_dim, + "time_proj_dim": time_proj_dim, + "global_states_input_dim": global_states_input_dim, + "cross_attention_input_dim": cross_attention_input_dim, + }, + )() + + # Time projection (Gaussian Fourier features) + # time_proj_dim is the OUTPUT dimension (after sin/cos concatenation) + # So embedding_size = time_proj_dim // 2 + self.time_proj = StableAudioGaussianFourierProjection(embedding_size=time_proj_dim // 2) + + # Timestep projection: time_proj_dim -> inner_dim + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + # Global states projection (for audio duration conditioning) + # Output is inner_dim, added to time embedding + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + # Cross-attention input projection + # Always use Sequential(Linear, SiLU, Linear) to match diffusers structure + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + # Pre-processing conv (residual connection) + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + + # Input projection: in_channels -> inner_dim (64 -> 1536) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + # Transformer blocks - work on inner_dim (1536) + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + # Output projection: inner_dim -> out_channels (1536 -> 64) + 
self.proj_out = nn.Linear(self.inner_dim, out_channels, bias=False) + + # Post-processing conv (residual connection) + self.postprocess_conv = nn.Conv1d(out_channels, out_channels, 1, bias=False) + + @property + def dtype(self) -> torch.dtype: + """Return the dtype of the model parameters.""" + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + global_hidden_states: torch.Tensor | None = None, + rotary_embedding: tuple[torch.Tensor, torch.Tensor] | None = None, + return_dict: bool = True, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + Forward pass of the Stable Audio DiT model. + + Args: + hidden_states: Input latent tensor [B, C, L] (C=in_channels=64) + timestep: Timestep tensor [B] or [1] + encoder_hidden_states: Text/condition embeddings [B, S, D] + global_hidden_states: Global conditioning (duration) [B, 1, D] + rotary_embedding: Precomputed rotary embeddings (cos, sin) + return_dict: Whether to return a dataclass or tuple + attention_mask: Attention mask for self-attention + encoder_attention_mask: Attention mask for cross-attention + + Returns: + Denoised latent tensor + """ + # Project cross-attention inputs + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + + # Global embedding projection [B, 1, D] -> [B, 1, inner_dim] + global_hidden_states = self.global_proj(global_hidden_states) + + # Time embedding: timestep -> time_proj -> timestep_proj + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + # Combine global and time embeddings [B, 1, inner_dim] + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + # Pre-process with residual: [B, C, L] + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + + # Transpose: [B, C, L] -> [B, L, C] + hidden_states = hidden_states.transpose(1, 2) + + # Project to inner_dim: [B, L, C] -> [B, L, inner_dim] + hidden_states = self.proj_in(hidden_states) + + # Prepend global states to hidden states: [B, 1+L, inner_dim] + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=1) + + # Update attention mask if provided + if attention_mask is not None: + prepend_mask = torch.ones( + (hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=torch.bool, + ) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + # Transformer blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + + # Project back to out_channels: [B, 1+L, inner_dim] -> [B, 1+L, out_channels] + hidden_states = self.proj_out(hidden_states) + + # Transpose and remove prepended global token: [B, L, C] -> [B, C, L] + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + + # Post-process with residual: [B, C, L] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if return_dict: + return Transformer2DModelOutput(sample=hidden_states) + return (hidden_states,) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from a pretrained model. + + Maps diffusers weight names to our module structure. + + Returns: + Set of parameter names that were successfully loaded. 
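+            Weight names that do not match any parameter in this module are
+            skipped and logged at debug level.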
+ """ + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + # Weight name mapping from diffusers to our implementation + name_mapping = { + # Timestep projection - diffusers uses index-based naming + "timestep_proj.linear_1.weight": "timestep_proj.0.weight", + "timestep_proj.linear_1.bias": "timestep_proj.0.bias", + "timestep_proj.linear_2.weight": "timestep_proj.2.weight", + "timestep_proj.linear_2.bias": "timestep_proj.2.bias", + # Global projection - diffusers uses index-based naming + "global_proj.linear_1.weight": "global_proj.0.weight", + "global_proj.linear_2.weight": "global_proj.2.weight", + } + + for name, loaded_weight in weights: + # Apply name mapping if needed + mapped_name = name_mapping.get(name, name) + + if mapped_name in params_dict: + param = params_dict[mapped_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(mapped_name) + else: + logger.debug(f"Skipping weight {name} - not found in model") + + return loaded_params diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c337f58a4a292a0778f8e774ca1e4acad267683c --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -0,0 +1,35 @@ +from .pipeline_wan2_2 import ( + Wan22Pipeline, + create_transformer_from_config, + get_wan22_post_process_func, + get_wan22_pre_process_func, + load_transformer_config, + retrieve_latents, +) +from .pipeline_wan2_2_i2v import ( + Wan22I2VPipeline, + get_wan22_i2v_post_process_func, + get_wan22_i2v_pre_process_func, +) +from .pipeline_wan2_2_ti2v import ( + Wan22TI2VPipeline, + get_wan22_ti2v_post_process_func, + get_wan22_ti2v_pre_process_func, +) +from .wan2_2_transformer import WanTransformer3DModel + +__all__ = [ + "Wan22Pipeline", + "get_wan22_post_process_func", + "get_wan22_pre_process_func", + "retrieve_latents", + "load_transformer_config", + "create_transformer_from_config", + "Wan22I2VPipeline", + "get_wan22_i2v_post_process_func", + "get_wan22_i2v_pre_process_func", + "Wan22TI2VPipeline", + "get_wan22_ti2v_post_process_func", + "get_wan22_ti2v_pre_process_func", + "WanTransformer3DModel", +] diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py new file mode 100644 index 0000000000000000000000000000000000000000..b902bc692e142ebcc6977ca5913dba8355726173 --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -0,0 +1,810 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +import logging +import os +from collections.abc import Iterable +from typing import Any, cast + +import PIL.Image +import torch +from diffusers import AutoencoderKLWan +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, UMT5EncoderModel +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from 
vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform + +logger = logging.getLogger(__name__) + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +): + """Retrieve latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def load_transformer_config(model_path: str, subfolder: str = "transformer", local_files_only: bool = True) -> dict: + """Load transformer config from model directory or HF Hub.""" + if local_files_only: + config_path = os.path.join(model_path, subfolder, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + return json.load(f) + else: + # Try to download config from HF Hub + try: + from huggingface_hub import hf_hub_download + + config_path = hf_hub_download( + repo_id=model_path, + filename=f"{subfolder}/config.json", + ) + with open(config_path) as f: + return json.load(f) + except Exception: + pass + return {} + + +def create_transformer_from_config(config: dict) -> WanTransformer3DModel: + """Create WanTransformer3DModel from config dict.""" + kwargs = {} + + if "patch_size" in config: + kwargs["patch_size"] = tuple(config["patch_size"]) + if "num_attention_heads" in config: + kwargs["num_attention_heads"] = config["num_attention_heads"] + if "attention_head_dim" in config: + kwargs["attention_head_dim"] = config["attention_head_dim"] + if "in_channels" in config: + kwargs["in_channels"] = config["in_channels"] + if "out_channels" in config: + kwargs["out_channels"] = config["out_channels"] + if "text_dim" in config: + kwargs["text_dim"] = config["text_dim"] + if "freq_dim" in config: + kwargs["freq_dim"] = config["freq_dim"] + if "ffn_dim" in config: + kwargs["ffn_dim"] = config["ffn_dim"] + if "num_layers" in config: + kwargs["num_layers"] = config["num_layers"] + if "cross_attn_norm" in config: + kwargs["cross_attn_norm"] = config["cross_attn_norm"] + if "eps" in config: + kwargs["eps"] = config["eps"] + if "image_dim" in config: + kwargs["image_dim"] = config["image_dim"] + if "added_kv_proj_dim" in config: + kwargs["added_kv_proj_dim"] = config["added_kv_proj_dim"] + if "rope_max_seq_len" in config: + kwargs["rope_max_seq_len"] = config["rope_max_seq_len"] + if "pos_embed_seq_len" in config: + kwargs["pos_embed_seq_len"] = config["pos_embed_seq_len"] + + return WanTransformer3DModel(**kwargs) + + +def get_wan22_post_process_func( + od_config: OmniDiffusionConfig, +): + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def post_process_func( + video: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return video + return video_processor.postprocess_video(video, output_type=output_type) + + return post_process_func + + +def get_wan22_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for Wan2.2: optionally load and resize input image for I2V mode.""" + import 
numpy as np + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + continue + + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 720P + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request + + return pre_process_func + + +class Wan22Pipeline(nn.Module, CFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.bfloat16) + + model = od_config.model + local_files_only = os.path.exists(model) + + # Read model_index.json to detect expand_timesteps mode (for TI2V-5B) + self.expand_timesteps = False + self.has_transformer_2 = False + if local_files_only: + model_index_path = os.path.join(model, "model_index.json") + if os.path.exists(model_index_path): + with open(model_index_path) as f: + model_index = json.load(f) + self.expand_timesteps = model_index.get("expand_timesteps", False) + # Check if this is a two-stage model (MoE with transformer_2) + transformer_2_path = os.path.join(model, "transformer_2") + self.has_transformer_2 = os.path.exists(transformer_2_path) + else: + # For remote models, download and read model_index.json + try: + from huggingface_hub import hf_hub_download + + model_index_path = hf_hub_download(repo_id=model, filename="model_index.json") + with open(model_index_path) as f: + model_index = json.load(f) + self.expand_timesteps = model_index.get("expand_timesteps", False) + # Check transformer_2 from 
model_index + transformer_2_info = model_index.get("transformer_2", [None, None]) + self.has_transformer_2 = transformer_2_info[0] is not None + except Exception: + pass + + self.boundary_ratio = od_config.boundary_ratio + + # Determine which transformers to load based on boundary_ratio + # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) + # boundary_ratio=0.0: only load transformer (high-noise stage only) + # otherwise: load both transformers + load_transformer = self.boundary_ratio != 1.0 if self.boundary_ratio is not None else True + load_transformer_2 = self.has_transformer_2 and ( + self.boundary_ratio != 0.0 if self.boundary_ratio is not None else True + ) + + # Set up weights sources for transformer(s) + self.weights_sources = [] + if load_transformer: + self.weights_sources.append( + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ) + if load_transformer_2: + self.weights_sources.append( + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer_2", + revision=None, + prefix="transformer_2.", + fall_back_to_pt=True, + ) + ) + + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.text_encoder = UMT5EncoderModel.from_pretrained( + model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only + ).to(self.device) + self.vae = AutoencoderKLWan.from_pretrained( + model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only + ).to(self.device) + + # Initialize transformers with correct config (weights loaded via load_weights) + if load_transformer: + transformer_config = load_transformer_config(model, "transformer", local_files_only) + self.transformer = create_transformer_from_config(transformer_config) + else: + self.transformer = None + + if load_transformer_2: + transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) + self.transformer_2 = create_transformer_from_config(transformer_2_config) + else: + self.transformer_2 = None + + # Store the active transformer config + if load_transformer: + self.transformer_config = self.transformer.config + elif load_transformer_2: + self.transformer_config = self.transformer_2.config + else: + raise RuntimeError("No transformer loaded") + + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self._guidance_scale = None + self._guidance_scale_2 = None + self._num_timesteps = None + self._current_timestep = None + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | None = None, + negative_prompt: str | None = None, + height: int = 480, + 
width: int = 832, + num_inference_steps: int = 40, + guidance_scale: float | tuple[float, float] = 4.0, + frame_num: int = 81, + output_type: str | None = "np", + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_kwargs: dict | None = None, + **kwargs, + ) -> DiffusionOutput: + # Get parameters from request or arguments + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + if prompt is None and prompt_embeds is None: + raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num + + # Ensure dimensions are compatible with VAE and patch size + # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size) + patch_size = self.transformer_config.patch_size + mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V + height = (height // mod_value) * mod_value + width = (width // mod_value) * mod_value + num_steps = req.sampling_params.num_inference_steps or num_inference_steps + + # Respect per-request guidance_scale when explicitly provided. 
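+        # `guidance_scale` may be a single float (used for both stages) or a
+        # pair whose first value drives the high-noise stage and second value
+        # the low-noise stage; `guidance_scale_2` on the request, when set,
+        # overrides the low-noise value resolved below.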
+ if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + + guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] + guidance_high = ( + req.sampling_params.guidance_scale_2 + if req.sampling_params.guidance_scale_2 is not None + else ( + guidance_scale[1] + if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 + else guidance_low + ) + ) + + # record guidance for properties + self._guidance_scale = guidance_low + self._guidance_scale_2 = guidance_high + + # validate shapes + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale_2=guidance_high if self.boundary_ratio is not None else None, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + device = self.device + # Get dtype from whichever transformer is loaded + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.transformer_2 is not None: + dtype = self.transformer_2.dtype + else: + # Fallback to text_encoder dtype if no transformer loaded + dtype = self.text_encoder.dtype + + # Seed / generator + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) + + # Encode prompts + if prompt_embeds is None: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, + device=device, + dtype=dtype, + ) + else: + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype) + elif guidance_low > 1.0 or guidance_high > 1.0: + raise ValueError( + "negative_prompt_embeds must be provided when prompt_embeds are given and guidance > 1." + ) + + # Timesteps + self.scheduler.set_timesteps(num_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + boundary_timestep = None + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + + # Handle I2V mode when expand_timesteps=True and image is provided + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. 
Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if raw_image is None: + image = None + elif isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) + + latent_condition = None + first_frame_mask = None + + if self.expand_timesteps and image is not None: + # I2V mode: encode image and prepare condition + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Preprocess image + if isinstance(image, PIL.Image.Image): + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + image_tensor = video_processor.preprocess(image, height=height, width=width) + else: + image_tensor = image + + # Use out_channels for noise latents (not in_channels which includes condition) + num_channels_latents = self.transformer_config.out_channels + batch_size = prompt_embeds.shape[0] + + # Prepare noise latents + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=req.sampling_params.latents, + ) + + # Encode image condition + num_latent_frames = latents.shape[2] + latent_height = latents.shape[3] + latent_width = latents.shape[4] + + image_tensor = image_tensor.unsqueeze(2) # [B, C, 1, H, W] + image_tensor = image_tensor.to(device=device, dtype=self.vae.dtype) + latent_condition = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + # Normalize condition latents + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latent_condition.device, latent_condition.dtype + ) + latent_condition = (latent_condition - latents_mean) * latents_std + latent_condition = latent_condition.to(torch.float32) + + # Create mask: 0 for first frame (condition), 1 for rest (to denoise) + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=torch.float32, device=device + ) + first_frame_mask[:, :, 0] = 0 + else: + # T2V mode: standard latent preparation + num_channels_latents = self.transformer_config.in_channels + latents = self.prepare_latents( + batch_size=prompt_embeds.shape[0], + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=req.sampling_params.latents, + ) + + if attention_kwargs is None: + attention_kwargs = {} + + # Denoising + for t in timesteps: + self._current_timestep = t + + # Select model based on timestep and boundary_ratio + # High noise stage (t >= boundary_timestep): use transformer + # Low noise stage (t < boundary_timestep): use transformer_2 + if boundary_timestep is not None and t < boundary_timestep: + # Low noise stage - always use guidance_high for this stage + current_guidance_scale = guidance_high + if self.transformer_2 is not None: + current_model = self.transformer_2 + elif self.transformer is not None: + # Fallback to transformer if transformer_2 not loaded + current_model = self.transformer + else: + raise 
RuntimeError("No transformer available for low-noise stage") + else: + # High noise stage - always use guidance_low for this stage + current_guidance_scale = guidance_low + if self.transformer is not None: + current_model = self.transformer + elif self.transformer_2 is not None: + # Fallback to transformer_2 if transformer not loaded + current_model = self.transformer_2 + else: + raise RuntimeError("No transformer available for high-noise stage") + + if self.expand_timesteps and latent_condition is not None: + # I2V mode: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps per patch - use floor division to match patch embedding + patch_size = self.transformer_config.patch_size + num_latent_frames = latents.shape[2] + patch_height = latents.shape[3] // patch_size[1] + patch_width = latents.shape[4] // patch_size[2] + + # Create mask at patch resolution (same as hidden states sequence length) + patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]] + patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions + temp_ts = (patch_mask[0][0] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # T2V mode: standard forward + latent_model_input = latents.to(dtype) + timestep = t.expand(latents.shape[0]) + + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. 
+ if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + self._current_timestep = None + + # For I2V mode: blend final latents with condition + if self.expand_timesteps and latent_condition is not None: + latents = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + + # Decode + if output_type == "latent": + output = latents + else: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + output = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=output) + + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_clean = [self._prompt_clean(p) for p in prompt] + batch_size = len(prompt_clean) + + text_inputs = self.tokenizer( + prompt_clean, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_text_inputs = self.tokenizer( + [self._prompt_clean(p) for p in negative_prompt], + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids_neg, mask_neg = neg_text_inputs.input_ids, neg_text_inputs.attention_mask + seq_lens_neg = mask_neg.gt(0).sum(dim=1).long() + negative_prompt_embeds = self.text_encoder(ids_neg.to(device), mask_neg.to(device)).last_hidden_state + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, 
seq_lens_neg)] + negative_prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + for u in negative_prompt_embeds + ], + dim=0, + ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + @staticmethod + def _prompt_clean(text: str) -> str: + return " ".join(text.strip().split()) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype | None, + device: torch.device | None, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError(f"Generator list length {len(generator)} does not match batch size {batch_size}.") + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and " + f"`negative_prompt_embeds`: {negative_prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..1aed9b75de3c60db1f73c201d957e0f6ea25e632 --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -0,0 +1,782 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLWan +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + create_transformer_from_config, + load_transformer_config, + retrieve_latents, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform + +logger = logging.getLogger(__name__) + + +def _load_model_index(model: str, local_files_only: bool) -> dict: + """Load model_index.json from local path or HF Hub.""" + if local_files_only: + model_index_path = os.path.join(model, "model_index.json") + if os.path.exists(model_index_path): + import json + + with open(model_index_path) as f: + return json.load(f) + else: + try: + import json + + from huggingface_hub import hf_hub_download + + model_index_path = hf_hub_download(model, "model_index.json") + with open(model_index_path) as f: + return json.load(f) + except Exception: + pass + return {} + + +def get_wan22_i2v_post_process_func( + od_config: OmniDiffusionConfig, +): + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def post_process_func( + video: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return video + return video_processor.postprocess_video(video, output_type=output_type) + + return post_process_func + + +def get_wan22_i2v_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for I2V: load and resize input image.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def 
pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + """No image is provided. This model requires an image to run.""", + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""", + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 480P + max_area = 480 * 832 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request + + return pre_process_func + + +class Wan22I2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): + """ + Wan2.2 Image-to-Video Pipeline. + + Supports both Wan2.1-style I2V (with CLIP image embeddings) and + Wan2.2-style I2V (with expand_timesteps for TI2V-5B). 
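+    The variant is selected from `model_index.json` (presence of `image_encoder`
+    and `transformer_2`) together with the `expand_timesteps` flag on the
+    diffusion config.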
+ """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.bfloat16) + + model = od_config.model + local_files_only = os.path.exists(model) + + # Set up weights sources for transformer(s) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + # Load model_index.json to detect available components + model_index = _load_model_index(model, local_files_only) + + # Check if this is a two-stage model (MoE with transformer_2) + self.has_transformer_2 = "transformer_2" in model_index + + if self.has_transformer_2: + self.weights_sources.append( + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer_2", + revision=None, + prefix="transformer_2.", + fall_back_to_pt=True, + ) + ) + + # Text encoder + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.text_encoder = UMT5EncoderModel.from_pretrained( + model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only + ).to(self.device) + + # Image encoder (CLIP) - optional, for Wan2.1-style I2V + self.has_image_encoder = "image_encoder" in model_index and model_index["image_encoder"][0] is not None + + if self.has_image_encoder: + self.image_processor = CLIPImageProcessor.from_pretrained( + model, subfolder="image_processor", local_files_only=local_files_only + ) + self.image_encoder = CLIPVisionModel.from_pretrained( + model, subfolder="image_encoder", torch_dtype=dtype, local_files_only=local_files_only + ).to(self.device) + else: + self.image_processor = None + self.image_encoder = None + + # VAE + self.vae = AutoencoderKLWan.from_pretrained( + model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only + ).to(self.device) + + # Transformers (weights loaded via load_weights) + # Load config from model directory or HF Hub to get correct in_channels for I2V models + transformer_config = load_transformer_config(model, "transformer", local_files_only) + self.transformer = create_transformer_from_config(transformer_config) + if self.has_transformer_2: + transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) + self.transformer_2 = create_transformer_from_config(transformer_2_config) + else: + self.transformer_2 = None + + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + + # VAE scale factors + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if hasattr(self.vae, "config") else 8 + + # MoE boundary ratio for two-stage denoising + self.boundary_ratio = od_config.boundary_ratio + + # Whether to use expand_timesteps mode (for TI2V-5B style) + self.expand_timesteps = getattr(od_config, "expand_timesteps", False) + + self._guidance_scale = None + self._guidance_scale_2 = None + self._num_timesteps = None + self._current_timestep = None + + @property + def guidance_scale(self): + return self._guidance_scale + + 
@property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + def encode_image( + self, + image: PIL.Image.Image | list[PIL.Image.Image], + device: torch.device | None = None, + ) -> torch.Tensor: + """Encode image using CLIP image encoder.""" + device = device or self.device + if self.image_encoder is None: + raise ValueError("Image encoder not available for this model.") + + pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=device, dtype=self.image_encoder.dtype) + image_embeds = self.image_encoder(pixel_values, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | None = None, + negative_prompt: str | None = None, + image: PIL.Image.Image | torch.Tensor | None = None, + height: int = 480, + width: int = 832, + num_inference_steps: int = 40, + guidance_scale: float | tuple[float, float] = 5.0, + frame_num: int = 81, + output_type: str | None = "np", + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: PIL.Image.Image | torch.Tensor | None = None, + attention_kwargs: dict | None = None, + **kwargs, + ) -> DiffusionOutput: + # Get parameters from request or arguments + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + if prompt is None and prompt_embeds is None: + raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") + + # Get image from request + if image is None: + multi_modal_data = ( + req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + ) + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if raw_image is None: + raise ValueError("Image is required for I2V generation.") + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames or frame_num + num_steps = req.sampling_params.num_inference_steps or num_inference_steps + + # Respect per-request guidance_scale when explicitly provided. 
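+        # As in the T2V pipeline, `guidance_scale` may be a single float or a
+        # pair (high-noise stage, low-noise stage); `guidance_scale_2` on the
+        # request overrides the low-noise value when provided.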
+ if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + + # Handle guidance scales + guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] + guidance_high = ( + req.sampling_params.guidance_scale_2 + if req.sampling_params.guidance_scale_2 is not None + else ( + guidance_scale[1] + if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 + else guidance_low + ) + ) + + self._guidance_scale = guidance_low + self._guidance_scale_2 = guidance_high + + # Validate inputs + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + guidance_scale_2=guidance_high if self.boundary_ratio is not None else None, + ) + + # Adjust num_frames to be compatible with VAE temporal scaling + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + device = self.device + dtype = self.transformer.dtype + + # Generator setup + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) + + # Encode prompts + if prompt_embeds is None: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, + device=device, + dtype=dtype, + ) + else: + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype) + + batch_size = prompt_embeds.shape[0] + + # Encode image embeddings (for Wan2.1-style with CLIP) + if self.has_image_encoder and self.transformer.config.image_dim is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(dtype) + else: + image_embeds = None + + # Timesteps + self.scheduler.set_timesteps(num_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + boundary_timestep = None + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + + # Prepare latents (use out_channels=16 for VAE latent, not in_channels=36) + num_channels_latents = self.transformer.config.out_channels + + # Preprocess image for VAE + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + if isinstance(image, PIL.Image.Image): + image_tensor = video_processor.preprocess(image, height=height, width=width) + else: + image_tensor = image + image_tensor = image_tensor.to(device=device, dtype=torch.float32) + + # Handle last_image if provided + if last_image is not None: + if isinstance(last_image, PIL.Image.Image): + last_image_tensor = video_processor.preprocess(last_image, height=height, 
width=width) + else: + last_image_tensor = last_image + last_image_tensor = last_image_tensor.to(device=device, dtype=torch.float32) + else: + last_image_tensor = None + + latents, condition, first_frame_mask = self.prepare_latents( + image=image_tensor, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=req.sampling_params.latents, + last_image=last_image_tensor, + ) + + if attention_kwargs is None: + attention_kwargs = {} + + # Denoising loop + for t in timesteps: + self._current_timestep = t + + # Select model and guidance scale based on timestep + current_model = self.transformer + current_guidance_scale = guidance_low + if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None: + current_model = self.transformer_2 + current_guidance_scale = guidance_high + + # Prepare latent input + if self.expand_timesteps: + # TI2V-5B style: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps for each patch + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # Wan2.1 style: concatenate condition with latents + latent_model_input = torch.cat([latents, condition], dim=1).to(dtype) + timestep = t.expand(latents.shape[0]) + + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. 
+ if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + self._current_timestep = None + + # For expand_timesteps mode, blend final latents with condition + if self.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + # Decode + if output_type == "latent": + output = latents + else: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + output = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=output) + + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """Encode text prompts using T5 text encoder.""" + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_clean = [self._prompt_clean(p) for p in prompt] + batch_size = len(prompt_clean) + + text_inputs = self.tokenizer( + prompt_clean, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_text_inputs = self.tokenizer( + [self._prompt_clean(p) for p in negative_prompt], + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids_neg, mask_neg = neg_text_inputs.input_ids, neg_text_inputs.attention_mask + seq_lens_neg = mask_neg.gt(0).sum(dim=1).long() + negative_prompt_embeds = self.text_encoder(ids_neg.to(device), mask_neg.to(device)).last_hidden_state + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + negative_prompt_embeds = [u[:v] for u, v in 
zip(negative_prompt_embeds, seq_lens_neg)] + negative_prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + for u in negative_prompt_embeds + ], + dim=0, + ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + @staticmethod + def _prompt_clean(text: str) -> str: + return " ".join(text.strip().split()) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype | None, + device: torch.device | None, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare latents for I2V generation. + + Returns: + latents: Initial noise latents + condition: Encoded image condition (concatenated with mask for non-expand mode) + first_frame_mask: Mask for the first frame (1 for frames to denoise, 0 for condition) + """ + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # Prepare image condition + image = image.unsqueeze(2) # [batch, channels, 1, height, width] + + if self.expand_timesteps: + # TI2V-5B style: only use first frame as condition + video_condition = image + elif last_image is None: + # Pad with zeros for remaining frames + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + # First and last frame conditioning + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + # Encode through VAE + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + # Normalize latents + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latent_condition.device, latent_condition.dtype + ) + latent_condition = (latent_condition - latents_mean) * latents_std + latent_condition = latent_condition.to(dtype) + + if self.expand_timesteps: + # TI2V-5B style: create mask where first frame is 0 (condition), rest is 1 (to denoise) + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device + ) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + # Wan2.1 style: create mask and concatenate with condition + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + 
mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + # Concatenate mask with condition for channel dimension + condition = torch.concat([mask_lat_size, latent_condition], dim=1) + + # For non-expand mode, first_frame_mask is not used in the same way + first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device) + + return latents, condition, first_frame_mask + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + guidance_scale_2=None, + ): + if image is None and image_embeds is None: + raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") + + if image is not None and image_embeds is not None: + raise ValueError("Cannot forward both `image` and `image_embeds`. Please provide only one.") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`. Please provide only one.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Cannot forward both `negative_prompt` and `negative_prompt_embeds`. Please provide only one." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py new file mode 100644 index 0000000000000000000000000000000000000000..d32b7d697cd663d5a65f6c8bdb68c1aa29de0e1a --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Wan2.2 TI2V (Text-Image-to-Video) Pipeline. 
+ +This pipeline supports the unified TI2V-5B model that can generate videos from: +- Text only (T2V mode) +- Text + Image (I2V mode) + +The key difference from the MoE-based I2V pipeline is: +- Single transformer (not MoE with two transformers) +- Uses expand_timesteps mode for image conditioning +- No CLIP image encoder - only VAE encoding for image condition +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLWan +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, UMT5EncoderModel +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + create_transformer_from_config, + load_transformer_config, + retrieve_latents, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform + +logger = logging.getLogger(__name__) + + +def get_wan22_ti2v_post_process_func( + od_config: OmniDiffusionConfig, +): + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def post_process_func( + video: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return video + return video_processor.postprocess_video(video, output_type=output_type) + + return post_process_func + + +def get_wan22_ti2v_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for TI2V: optionally load and resize input image.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + """No image is provided. 
This model requires an image to run.""", + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""", + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": <an image object or file path>, …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 720P (TI2V-5B default) + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request + + return pre_process_func + + +class Wan22TI2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): + """ + Wan2.2 Text-Image-to-Video (TI2V) Pipeline. + + This is a unified pipeline that supports both: + - Text-to-Video (T2V): when no image is provided + - Image-to-Video (I2V): when an image is provided + + Uses expand_timesteps mode for I2V conditioning where the first frame + is conditioned on the input image latent. 
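+
+    When an image is provided, only the non-conditioned frames are updated at each
+    denoising step; as a sketch of the blend applied in ``forward()``:
+
+        latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+
+    where ``first_frame_mask`` is 0 on the first latent frame (the VAE-encoded
+    input image) and 1 on the frames being denoised.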
+ """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.bfloat16) + + model = od_config.model + local_files_only = os.path.exists(model) + + # Set up weights sources for single transformer + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + # Text encoder + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.text_encoder = UMT5EncoderModel.from_pretrained( + model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only + ).to(self.device) + + # VAE + self.vae = AutoencoderKLWan.from_pretrained( + model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only + ).to(self.device) + + # Single transformer (TI2V uses dense 5B model, not MoE) + # Load config from model to get correct dimensions + transformer_config = load_transformer_config(model, "transformer", local_files_only) + self.transformer = create_transformer_from_config(transformer_config) + + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + + # VAE scale factors + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if hasattr(self.vae, "config") else 8 + + # TI2V always uses expand_timesteps mode + self.expand_timesteps = True + + self._guidance_scale = None + self._num_timesteps = None + self._current_timestep = None + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | None = None, + negative_prompt: str | None = None, + image: PIL.Image.Image | torch.Tensor | None = None, + height: int = 704, + width: int = 1280, + num_inference_steps: int = 40, + guidance_scale: float = 5.0, + frame_num: int = 81, + output_type: str | None = "np", + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_kwargs: dict | None = None, + **kwargs, + ) -> DiffusionOutput: + # Get parameters from request or arguments + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + if prompt is None and prompt_embeds is None: + raise ValueError("Prompt or prompt_embeds is 
required for Wan2.2 generation.") + + # Get image from request (optional for TI2V) + if image is None: + multi_modal_data = ( + req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + ) + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if raw_image is None: + image = None + elif isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) + + # Default dimensions for TI2V-5B (720P) + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num + num_steps = req.sampling_params.num_inference_steps or num_inference_steps + + # Respect per-request guidance_scale when explicitly provided. + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + + self._guidance_scale = guidance_scale + + # Validate inputs + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # Adjust num_frames to be compatible with VAE temporal scaling + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + device = self.device + dtype = self.transformer.dtype + + # Generator setup + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) + + # Encode prompts + if prompt_embeds is None: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=guidance_scale > 1.0, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, + device=device, + dtype=dtype, + ) + else: + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype) + + batch_size = prompt_embeds.shape[0] + + # Timesteps + self.scheduler.set_timesteps(num_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # Prepare latents + num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + # Check if we have an image (I2V mode) or not (T2V mode) + if image is not None: + # I2V mode: prepare latents with image condition + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + if isinstance(image, PIL.Image.Image): + image_tensor = video_processor.preprocess(image, height=height, width=width) + else: + image_tensor = image + image_tensor = image_tensor.to(device=device, 
dtype=torch.float32) + + latents, latent_condition, first_frame_mask = self.prepare_i2v_latents( + image=image_tensor, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=req.sampling_params.latents, + ) + else: + # T2V mode: prepare random latents + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=req.sampling_params.latents, + ) + latent_condition = None + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=torch.float32, device=device + ) + + if attention_kwargs is None: + attention_kwargs = {} + + # Denoising loop + for t in timesteps: + self._current_timestep = t + + # Prepare latent input + if latent_condition is not None: + # I2V mode: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps for each patch (TI2V style) + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # T2V mode: use latents directly + latent_model_input = latents.to(dtype) + + # Expand timesteps for TI2V model architecture + mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=device) + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + + do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. 
+ if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + self._current_timestep = None + + # For I2V mode, blend final latents with condition + if latent_condition is not None: + latents = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + + # Decode + if output_type == "latent": + output = latents + else: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + output = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=output) + + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """Encode text prompts using T5 text encoder.""" + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_clean = [self._prompt_clean(p) for p in prompt] + batch_size = len(prompt_clean) + + text_inputs = self.tokenizer( + prompt_clean, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_text_inputs = self.tokenizer( + [self._prompt_clean(p) for p in negative_prompt], + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids_neg, mask_neg = neg_text_inputs.input_ids, neg_text_inputs.attention_mask + seq_lens_neg = mask_neg.gt(0).sum(dim=1).long() + negative_prompt_embeds = self.text_encoder(ids_neg.to(device), mask_neg.to(device)).last_hidden_state + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, 
seq_lens_neg)] + negative_prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + for u in negative_prompt_embeds + ], + dim=0, + ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + @staticmethod + def _prompt_clean(text: str) -> str: + return " ".join(text.strip().split()) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype | None, + device: torch.device | None, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare random latents for T2V mode.""" + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError(f"Generator list length {len(generator)} does not match batch size {batch_size}.") + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_i2v_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype | None, + device: torch.device | None, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare latents for I2V mode with image conditioning. 
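+
+        The latent grid has (num_frames - 1) // vae_scale_factor_temporal + 1 frames of
+        size (height // vae_scale_factor_spatial, width // vae_scale_factor_spatial);
+        e.g. 81 input frames with the default factors (4, 8) give 21 latent frames.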
+ + Returns: + latents: Initial noise latents + latent_condition: Encoded first frame condition + first_frame_mask: Mask (0 for first frame, 1 for rest) + """ + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # Prepare first frame condition + image = image.unsqueeze(2) # [batch, channels, 1, height, width] + image = image.to(device=device, dtype=self.vae.dtype) + + # Encode through VAE + latent_condition = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + # Normalize latents + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latent_condition.device, latent_condition.dtype + ) + latent_condition = (latent_condition - latents_mean) * latents_std + latent_condition = latent_condition.to(dtype) + + # Create mask: 0 for first frame (condition), 1 for rest (to denoise) + first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device) + first_frame_mask[:, :, 0] = 0 + + return latents, latent_condition, first_frame_mask + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`. Please provide only one.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Cannot forward both `negative_prompt` and `negative_prompt_embeds`. Please provide only one." 
+ ) + + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab92ad0c8821636ffff9a7f29bc2801712ed7ff7 --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -0,0 +1,756 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm +from vllm.logger import init_logger +from vllm.model_executor.layers.conv import Conv3dLayer +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) + +logger = init_logger(__name__) + + +def apply_rotary_emb_wan( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensors. + + Args: + hidden_states: Input tensor of shape [B, S, H, D] + freqs_cos: Cosine frequencies + freqs_sin: Sine frequencies + + Returns: + Tensor with rotary embeddings applied + """ + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +class WanRotaryPosEmbed(nn.Module): + """ + Rotary position embeddings for 3D video data (temporal + spatial dimensions). 
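+
+    The head dimension is split across the temporal, height and width axes;
+    e.g. with the default attention_head_dim=128 the per-axis split is (44, 42, 42).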
+ """ + + def __init__( + self, + attention_head_dim: int, + patch_size: tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + # Split dimensions for temporal, height, width + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = self._get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + @staticmethod + def _get_1d_rotary_pos_embed( + dim: int, + max_seq_len: int, + theta: float, + freqs_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Generate 1D rotary position embeddings.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype) / dim)) + t = torch.arange(max_seq_len, dtype=freqs_dtype) + freqs = torch.outer(t, freqs) + # Repeat interleave for real representation + freqs_cos = freqs.cos().repeat_interleave(2, dim=-1) + freqs_sin = freqs.sin().repeat_interleave(2, dim=-1) + return freqs_cos.float(), freqs_sin.float() + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +class WanImageEmbedding(nn.Module): + """Image embedding module for I2V tasks.""" + + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + 
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class WanTimeTextImageEmbedding(nn.Module): + """Combined time, text, and image condition embeddings.""" + + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: int | None = None, + pos_embed_seq_len: int | None = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class WanSelfAttention(nn.Module): + """ + Optimized self-attention module using vLLM layers. 
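+
+    Q, K and V are produced by a single fused ``to_qkv`` projection; the separate
+    ``to_q``/``to_k``/``to_v`` weights found in diffusers checkpoints are packed into
+    it by ``WanTransformer3DModel.load_weights`` via the stacked-params mapping.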
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + eps: float = 1e-5, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + + # Fused QKV projection using vLLM's optimized layer + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=head_dim, + total_num_heads=num_heads, + bias=True, + disable_tp=True, + ) + + # QK normalization using vLLM's RMSNorm + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.inner_dim, eps=eps) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + + # Unified attention layer + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Fused QKV projection + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + # Apply QK normalization + query = self.norm_q(query) + key = self.norm_k(key) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)) + key = key.unflatten(2, (self.num_heads, -1)) + value = value.unflatten(2, (self.num_heads, -1)) + + # Apply rotary embeddings + if rotary_emb is not None: + freqs_cos, freqs_sin = rotary_emb + query = apply_rotary_emb_wan(query, freqs_cos, freqs_sin) + key = apply_rotary_emb_wan(key, freqs_cos, freqs_sin) + + # Compute attention using unified attention layer + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class WanCrossAttention(nn.Module): + """ + Optimized cross-attention module using vLLM layers. + Handles both text cross-attention and optional image cross-attention (I2V). 
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: int | None = None, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + self.kv_inner_dim = head_dim * num_heads # For cross-attention, K/V come from encoder + + # Query projection + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True) + + # Separate K and V projections for cross-attention + self.to_k = ReplicatedLinear(dim, self.kv_inner_dim, bias=True) + self.to_v = ReplicatedLinear(dim, self.kv_inner_dim, bias=True) + + # QK normalization + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.kv_inner_dim, eps=eps) + + # Optional added KV projections for I2V (image embeddings) + self.added_kv_proj_dim = added_kv_proj_dim + if added_kv_proj_dim is not None: + self.add_k_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = RMSNorm(self.inner_dim, eps=eps) + else: + self.add_k_proj = None + self.add_v_proj = None + self.norm_added_k = None + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + + # Unified attention layer + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Handle I2V case where encoder_hidden_states contains both image and text + encoder_hidden_states_img = None + if self.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + # Query projection + query, _ = self.to_q(hidden_states) + query = self.norm_q(query) + + # KV projection from encoder + key, _ = self.to_k(encoder_hidden_states) + value, _ = self.to_v(encoder_hidden_states) + key = self.norm_k(key) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)) + key = key.unflatten(2, (self.num_heads, -1)) + value = value.unflatten(2, (self.num_heads, -1)) + + # I2V: Additional attention with image embeddings + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, _ = self.add_k_proj(encoder_hidden_states_img) + value_img, _ = self.add_v_proj(encoder_hidden_states_img) + key_img = self.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (self.num_heads, -1)) + value_img = value_img.unflatten(2, (self.num_heads, -1)) + + hidden_states_img = self.attn(query, key_img, value_img) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + # Main cross-attention using unified attention layer + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Add image attention output if present + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return 
hidden_states + + +class WanTransformerBlock(nn.Module): + """ + Transformer block for Wan model with self-attention, cross-attention, and FFN. + Uses scale-shift modulation from timestep embeddings. + """ + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + cross_attn_norm: bool = False, + ): + super().__init__() + + head_dim = dim // num_heads + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanSelfAttention( + dim=dim, + num_heads=num_heads, + head_dim=head_dim, + eps=eps, + ) + + # 2. Cross-attention + self.attn2 = WanCrossAttention( + dim=dim, + num_heads=num_heads, + head_dim=head_dim, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + # Scale-shift table for modulation + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class WanTransformer3DModel(nn.Module): + """ + Optimized Wan Transformer model for video generation using vLLM layers. + + This is an optimized version of the diffusers WanTransformer3DModel that uses + vLLM's efficient QKVParallelLinear and RMSNorm implementations. + + Sequence Parallelism: + This model supports non-intrusive SP via _sp_plan. The plan specifies: + - RoPE (cos/sin) splitting via rope module's split_output + - hidden_states splitting at first transformer block input + - Output gathering at proj_out layer + + The video sequence (flattened patches) is parallelized across GPUs. + + Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. 
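+
+    Shape sketch: a latent of shape [B, in_channels, F, H, W] is patchified with
+    patch_size (p_t, p_h, p_w) into (F // p_t) * (H // p_h) * (W // p_w) tokens of
+    width num_attention_heads * attention_head_dim, processed by the transformer
+    blocks, then projected by ``proj_out`` and unpatchified back to
+    [B, out_channels, F, H, W].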
+ + Args: + patch_size: 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + num_attention_heads: Number of attention heads + attention_head_dim: Dimension of each attention head + in_channels: Number of input channels + out_channels: Number of output channels + text_dim: Input dimension for text embeddings + freq_dim: Dimension for sinusoidal time embeddings + ffn_dim: Intermediate dimension in feed-forward network + num_layers: Number of transformer blocks + cross_attn_norm: Enable cross-attention normalization + eps: Epsilon value for normalization layers + image_dim: Optional image embedding dimension for I2V + added_kv_proj_dim: Optional added KV projection dimension for I2V + rope_max_seq_len: Maximum sequence length for rotary embeddings + pos_embed_seq_len: Optional position embedding sequence length + """ + + _repeated_blocks = ["WanTransformerBlock"] + _layerwise_offload_blocks_attr = "blocks" + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + } + + # Sequence Parallelism for Wan (following diffusers' _cp_plan pattern) + # + # The _sp_plan specifies sharding/gathering at module boundaries: + # - rope: Split both RoPE outputs (freqs_cos, freqs_sin) via split_output=True + # - blocks.0: Split hidden_states input at the first transformer block + # - proj_out: Gather outputs after the final projection layer + # + # Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism) + _sp_plan = { + # Shard RoPE embeddings after rope module computes them + "rope": { + 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # freqs_cos [1, seq, 1, dim] + 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # freqs_sin [1, seq, 1, dim] + }, + # Shard hidden_states at first transformer block input + # (after patch_embedding + flatten + transpose) + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), # [B, seq, dim] + }, + # Gather at proj_out (final linear projection before unpatchify) + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def __init__( + self, + patch_size: tuple[int, int, int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + eps: float = 1e-6, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: int | None = None, + ): + super().__init__() + + # Store config for compatibility + self.config = type( + "Config", + (), + { + "patch_size": patch_size, + "num_attention_heads": num_attention_heads, + "attention_head_dim": attention_head_dim, + "in_channels": in_channels, + "out_channels": out_channels, + "text_dim": text_dim, + "freq_dim": freq_dim, + "ffn_dim": ffn_dim, + "num_layers": num_layers, + "cross_attn_norm": cross_attn_norm, + "eps": eps, + "image_dim": image_dim, + "added_kv_proj_dim": added_kv_proj_dim, + "rope_max_seq_len": rope_max_seq_len, + "pos_embed_seq_len": pos_embed_seq_len, + }, + )() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. 
Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = Conv3dLayer( + in_channels=in_channels, + out_channels=inner_dim, + kernel_size=patch_size, + stride=patch_size, + ) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, eps, added_kv_proj_dim, cross_attn_norm) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + @property + def dtype(self) -> torch.dtype: + """Return the dtype of the model parameters.""" + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Compute RoPE embeddings (sharded by _sp_plan via split_output=True) + rotary_emb = self.rope(hidden_states) + + # Patch embedding and flatten to sequence + # (hidden_states is sharded at blocks.0 input by _sp_plan) + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # Handle timestep shape + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + if ts_seq_len is not None: + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # Transformer blocks + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Output norm, projection & unpatchify + if temb.ndim == 3: + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = 
hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from a pretrained model, handling the mapping from + separate Q/K/V projections to fused QKV projections for self-attention. + + Diffusers weight names: + - blocks.N.attn1.to_q/to_k/to_v -> fused to blocks.N.attn1.to_qkv (self-attention) + - blocks.N.attn2.to_q/to_k/to_v -> kept separate (cross-attention) + - blocks.N.attn1.norm_q/norm_k -> QK normalization for self-attention + + Returns: + Set of parameter names that were successfully loaded. + """ + # Stacked params mapping for self-attention QKV fusion + # Format: (param_name, shard_name, shard_id) + # Note: Only fuse attn1 (self-attention), NOT attn2 (cross-attention) + stacked_params_mapping = [ + # self-attention QKV fusion (attn1 only) + (".attn1.to_qkv", ".attn1.to_q", "q"), + (".attn1.to_qkv", ".attn1.to_k", "k"), + (".attn1.to_qkv", ".attn1.to_v", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/diffusion/models/z_image/__init__.py b/vllm_omni/diffusion/models/z_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..62d92900a251fc126a860c0ffa80c360c0b1a4a7 --- /dev/null +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -0,0 +1,629 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
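+
+# This module assembles the Z-Image text-to-image pipeline: a chat-templated
+# text encoder, a FlowMatchEulerDiscreteScheduler, an AutoencoderKL VAE, and
+# the ZImageTransformer2DModel implemented in z_image_transformer.py.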
+ +import inspect +import json +import os +from collections.abc import Callable, Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoModel, AutoTokenizer +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.z_image.z_image_transformer import ( + ZImageTransformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImagePipeline(nn.Module): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + self._execution_device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + self.text_encoder = AutoModel.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self._execution_device + ) + self.transformer = ZImageTransformer2DModel() + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + + # Note: Context parallelism is applied centrally in registry.initialize_model() + # following diffusers' pattern of enable_parallelism() at model loading time + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = 
[negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ) -> DiffusionOutput: + r""" + Function invoked when calling the 
pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + generator = req.sampling_params.generator + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_rescale is not None else guidance_scale + ) + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + # image = self.image_processor.postprocess(image, output_type=output_type) + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py new file mode 100644 index 
0000000000000000000000000000000000000000..2b7d4eb5b4fffe4325adae0aa68ff0f437c51228 --- /dev/null +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -0,0 +1,941 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# _sp_plan definition adapted from HuggingFace diffusers library (_cp_plan) + +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.cache.base import CachedTransformer +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) +from vllm_omni.diffusion.forward_context import ( + get_forward_context, + is_forward_context_available, +) +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + +logger = init_logger(__name__) + + +class UnifiedPrepare(nn.Module): + """Prepares unified tensors for transformer blocks. + + This module encapsulates the unification of x and cap tensors into unified + sequences. Similar to how Wan's `rope` module outputs rotary embeddings, + this module outputs unified tensors that can be sharded via _sp_plan's + split_output=True mechanism. + + This follows the diffusers pattern where tensor preparation happens in + a dedicated submodule, enabling _sp_plan hooks to work at module boundaries. + """ + + def forward( + self, + x: torch.Tensor, + x_cos: torch.Tensor, + x_sin: torch.Tensor, + cap_feats: torch.Tensor, + cap_cos: torch.Tensor, + cap_sin: torch.Tensor, + x_item_seqlens: list[int], + cap_item_seqlens: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Combine x and cap tensors into unified sequences. 
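+
+        Illustrative shapes (hypothetical): for a single batch item with
+        x_item_seqlens=[4096] and cap_item_seqlens=[128], each unified tensor is
+        padded to the combined maximum of 4224 tokens and the attention mask
+        marks all 4224 positions as valid.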
+ + Returns: + unified: Combined hidden states [batch, seq_len, dim] + unified_cos: Combined RoPE cos [batch, seq_len, rope_dim] + unified_sin: Combined RoPE sin [batch, seq_len, rope_dim] + unified_attn_mask: Combined attention mask [batch, seq_len] + """ + bsz = x.shape[0] + device = x.device + + unified = [] + unified_cos = [] + unified_sin = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_cos.append(torch.cat([x_cos[i][:x_len], cap_cos[i][:cap_len]])) + unified_sin.append(torch.cat([x_sin[i][:x_len], cap_sin[i][:cap_len]])) + + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_cos = pad_sequence(unified_cos, batch_first=True, padding_value=0.0) + unified_sin = pad_sequence(unified_sin, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + return unified, unified_cos, unified_sin, unified_attn_mask + + +def _positive_divisors(n: int) -> set[int]: + if n <= 0: + return set() + divs: set[int] = set() + for d in range(1, int(math.isqrt(n)) + 1): + if n % d == 0: + divs.add(d) + divs.add(n // d) + return divs + + +def _get_tensor_parallel_size_from_context() -> int: + if not is_forward_context_available(): + return 1 + try: + od_config = get_forward_context().omni_diffusion_config + if od_config is None: + return 1 + return int(od_config.parallel_config.tensor_parallel_size) + except Exception: + return 1 + + +def validate_zimage_tp_constraints( + *, + dim: int, + n_heads: int, + n_kv_heads: int, + in_channels: int, + all_patch_size: tuple[int, ...], + all_f_patch_size: tuple[int, ...], + tensor_parallel_size: int, +) -> tuple[int, list[int], list[int]]: + """Validate Z-Image TP constraints without requiring a distributed context. + + Returns: + (ffn_hidden_dim, final_out_dims, supported_tp_candidates) + """ + tp_size = int(tensor_parallel_size) + if tp_size <= 0: + raise ValueError(f"tensor_parallel_size must be > 0, got {tp_size}") + if dim % n_heads != 0: + raise ValueError(f"dim must be divisible by n_heads, got dim={dim}, n_heads={n_heads}") + if dim % tp_size != 0: + supported = sorted(_positive_divisors(dim)) + raise ValueError( + f"Z-Image requires dim % tensor_parallel_size == 0, but got dim={dim}, tp={tp_size}. " + f"Supported tp candidates by dim: {supported}" + ) + if n_heads % tp_size != 0: + supported = sorted(_positive_divisors(n_heads)) + raise ValueError( + f"Z-Image requires n_heads % tensor_parallel_size == 0, but got n_heads={n_heads}, tp={tp_size}. " + f"Supported tp candidates by n_heads: {supported}" + ) + if n_kv_heads % tp_size != 0: + supported = sorted(_positive_divisors(n_kv_heads)) + raise ValueError( + f"Z-Image requires n_kv_heads % tensor_parallel_size == 0, but got n_kv_heads={n_kv_heads}, " + f"tp={tp_size}. Supported tp candidates by n_kv_heads: {supported}" + ) + + ffn_hidden_dim = int(dim / 3 * 8) + if ffn_hidden_dim % tp_size != 0: + supported = sorted(_positive_divisors(ffn_hidden_dim)) + raise ValueError( + "Z-Image requires ffn_hidden_dim % tensor_parallel_size == 0 (for TP-sharded MLP), but got " + f"ffn_hidden_dim={ffn_hidden_dim}, tp={tp_size}. 
Supported tp candidates by ffn_hidden_dim: {supported}" + ) + + final_out_dims = [ + int(patch_size) * int(patch_size) * int(f_patch_size) * int(in_channels) + for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size) + ] + bad_final_out_dims = [d for d in final_out_dims if d % tp_size != 0] + if bad_final_out_dims: + supported = sorted(_positive_divisors(math.gcd(*final_out_dims))) + raise ValueError( + "Z-Image requires final projection out_features divisible by tensor_parallel_size, but got " + f"final_out_dims={final_out_dims}, tp={tp_size}. " + f"Supported tp candidates by final_out_dims gcd: {supported}" + ) + + supported_tp_candidates = sorted( + _positive_divisors(n_heads) + & _positive_divisors(n_kv_heads) + & _positive_divisors(dim) + & _positive_divisors(ffn_hidden_dim) + & _positive_divisors(math.gcd(*final_out_dims)) + ) + return ffn_hidden_dim, final_out_dims, supported_tp_candidates + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + mid_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + mid_size, + out_size, + bias=True, + ), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZImageAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + qk_norm: bool = True, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.dim = dim + self.total_num_heads = num_heads + self.total_num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + total_num_kv_heads=num_kv_heads, + bias=False, + ) + + assert qk_norm is True + self.norm_q = RMSNorm(self.head_dim, eps=eps) + self.norm_k = RMSNorm(self.head_dim, eps=eps) + + # NOTE: QKV is column-parallel on heads, so attention output is sharded + # on the last dim (dim / tp). Use row-parallel output projection to + # all-reduce back to full dim. 
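+        # Illustrative numbers (default Z-Image config, tp=2 assumed): dim=3840 and
+        # num_heads=30 give head_dim=128; each rank then holds 15 heads, the local
+        # attention output is 15 * 128 = 1920 wide, and RowParallelLinear all-reduces
+        # the output projection back to the full 3840 dims.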
+ self.to_out = nn.ModuleList( + [ + RowParallelLinear( + dim, + dim, + bias=False, + input_is_parallel=True, + return_bias=False, + ) + ] + ) + + self.attn = Attention( + num_heads=self.to_qkv.num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.to_qkv.num_kv_heads, + ) + self.rope = RotaryEmbedding(is_neox_style=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + qkv, _ = self.to_qkv(hidden_states) + q_size = self.to_qkv.num_heads * self.head_dim + kv_size = self.to_qkv.num_kv_heads * self.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + + query = query.unflatten(-1, (self.to_qkv.num_heads, -1)) + key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = self.attn( + query, + key, + value, + # attn_mask=attention_mask, # we don't support multi prompts now. + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + hidden_states = self.to_out[0](hidden_states) + + return hidden_states + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w13 = MergedColumnParallelLinear( + dim, + [hidden_dim] * 2, + bias=False, + return_bias=False, + ) + self.act = SiluAndMul() + self.w2 = RowParallelLinear( + hidden_dim, + dim, + bias=False, + input_is_parallel=True, + return_bias=False, + ) + + def forward(self, x): + return self.w2(self.act(self.w13(x))) + + +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + + self.attention = ZImageAttention( + dim=dim, + num_heads=n_heads, + num_kv_heads=n_kv_heads, + qk_norm=qk_norm, + eps=1e-5, + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + adaln_input: torch.Tensor | None = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + attention_mask=attn_mask, + cos=cos, 
+ sin=sin, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x) * scale_mlp, + ) + ) + else: + # Attention block + attn_out = self.attention( + self.attention_norm1(x), + attention_mask=attn_mask, + cos=cos, + sin=sin, + ) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c): + scale = 1.0 + self.adaLN_modulation(c) + x = self.norm_final(x) * scale.unsqueeze(1) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: list[int] = (16, 56, 56), + axes_lens: list[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.cos_cached = None + self.sin_cached = None + + @staticmethod + def precompute_freqs(dim: list[int], end: list[int], theta: float = 256.0): + with torch.device("cpu"): + cos_list = [] + sin_list = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + cos_list.append(torch.cos(freqs)) + sin_list.append(torch.sin(freqs)) + + return cos_list, sin_list + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.cos_cached is None: + self.cos_cached, self.sin_cached = self.precompute_freqs(self.axes_dims, self.axes_lens, theta=self.theta) + self.cos_cached = [c.to(device) for c in self.cos_cached] + self.sin_cached = [s.to(device) for s in self.sin_cached] + else: + # Ensure cached tensors are on the same device as ids + if self.cos_cached[0].device != device: + self.cos_cached = [c.to(device) for c in self.cos_cached] + self.sin_cached = [s.to(device) for s in self.sin_cached] + + cos_result = [] + sin_result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + cos_result.append(self.cos_cached[i][index]) + sin_result.append(self.sin_cached[i][index]) + + return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1) + + +class ZImageTransformer2DModel(CachedTransformer): + """Z-Image Transformer model for image generation. + + Sequence Parallelism: + This model supports non-intrusive SP via _sp_plan. The plan specifies: + - Input splitting at first main transformer block (unified sequence) + - RoPE (cos/sin) splitting along sequence dimension + - Attention mask splitting along sequence dimension + - Output gathering at final_layer + + The SP is applied to the main `layers` transformer blocks where the + unified image+caption sequence is processed jointly. + + Note: noise_refiner and context_refiner are NOT parallelized as they + process image and caption separately before unification. + + Important: The default _sp_plan assumes patch_size=2 and f_patch_size=1. 
+ If using different patch configurations, update _sp_plan accordingly. + + Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. + """ + + _repeated_blocks = ["ZImageTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "w13": ["w1", "w3"], + } + + # Sequence Parallelism for Z-Image (following diffusers' _cp_plan pattern) + # Similar to how Wan uses `rope` module's split_output to shard rotary embeddings, + # Z-Image uses `unified_prepare` module's split_output to shard unified tensors. + # + # The _sp_plan specifies sharding/gathering at module boundaries: + # - unified_prepare: Split all 4 outputs (unified, cos, sin, attn_mask) via split_output=True + # - layers.0: hidden_states input is already sharded from unified_prepare output + # - all_final_layer.2-1: Gather outputs after the final layer + # + # Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism) + _sp_plan = { + # Shard unified_prepare outputs (similar to Wan's rope module) + # This shards all 4 return values: unified, unified_cos, unified_sin, unified_attn_mask + "unified_prepare": { + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True), # unified + 1: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True), # unified_cos + 2: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True), # unified_sin + 3: SequenceParallelInput(split_dim=1, expected_dims=2, split_output=True), # unified_attn_mask + }, + # Gather output at final_layer (default: patch_size=2, f_patch_size=1) + "all_final_layer.2-1": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ) -> None: + super().__init__() + self.dtype = torch.bfloat16 + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + tp_size = _get_tensor_parallel_size_from_context() + ffn_hidden_dim, final_out_dims, supported_tp_candidates = validate_zimage_tp_constraints( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + in_channels=self.out_channels, + all_patch_size=tuple(all_patch_size), + all_f_patch_size=tuple(all_f_patch_size), + tensor_parallel_size=tp_size, + ) + + logger.info_once( + "Z-Image init: dim=%d n_heads=%d n_kv_heads=%d ffn_hidden_dim=%d final_out_dims=%s tp=%d (supported_tp=%s)", + dim, + n_heads, + n_kv_heads, + ffn_hidden_dim, + tuple(final_out_dims), + tp_size, + tuple(supported_tp_candidates), + ) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + 
self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + # UnifiedPrepare module for combining x and cap tensors + # This enables _cp_plan to shard outputs via split_output=True + # Similar to how Wan's rope module enables rotary embedding sharding + self.unified_prepare = UnifiedPrepare() + + def unpatchify(self, x: list[torch.Tensor], size: list[tuple], patch_size, f_patch_size) -> list[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: list[torch.Tensor], + all_cap_feats: list[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, 
W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: list[torch.Tensor], + t, + cap_feats: list[torch.Tensor], + patch_size=2, + f_patch_size=1, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_cos, x_sin = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) + x_cos = list(x_cos.split(x_item_seqlens, dim=0)) + x_sin = list(x_sin.split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_cos = pad_sequence(x_cos, batch_first=True, padding_value=0.0) + x_sin = pad_sequence(x_sin, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_cos, x_sin, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_cos, cap_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)) + cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0)) + cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_cos = pad_sequence(cap_cos, batch_first=True, padding_value=0.0) + cap_sin = 
pad_sequence(cap_sin, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_cos, cap_sin) + + # Prepare unified tensors via UnifiedPrepare module + # This enables _cp_plan to shard outputs via split_output=True + unified, unified_cos, unified_sin, unified_attn_mask = self.unified_prepare( + x, x_cos, x_sin, cap_feats, cap_cos, cap_sin, x_item_seqlens, cap_item_seqlens + ) + + # Main transformer blocks + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_cos, unified_sin, adaln_input) + + # Final layer + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + return x, {} + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self-attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # ffn + (".w13", ".w1", 0), + (".w13", ".w3", 1), + ] + + params_dict = dict(self.named_parameters()) + + loaded_params = set[str]() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/offload.py b/vllm_omni/diffusion/offload.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1d8a0db06be8750b02af0720bd449d46b38e3b --- /dev/null +++ b/vllm_omni/diffusion/offload.py @@ -0,0 +1,535 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CPU offloading utilities for diffusion models. + +This module provides mutual-exclusion CPU offloading between DiT and encoders. +When enable_cpu_offload is enabled: +- Text encoders run on GPU while DiT is on CPU +- DiT runs on GPU while encoders are offloaded to CPU + +This allows running large models on limited GPU memory. +""" + +from __future__ import annotations + +from functools import partial +from itertools import chain +from typing import TYPE_CHECKING, Any + +import torch +from torch import nn +from vllm.logger import init_logger + +from vllm_omni.platforms import current_omni_platform + +if TYPE_CHECKING: + from vllm_omni.diffusion.data import OmniDiffusionConfig + +logger = init_logger(__name__) + + +class SequentialOffloader: + """Sequential offloader: DiT and encoders take turns on GPU. + + Uses PyTorch's forward pre-hooks to automatically swap models: + - Before encoder runs: move DiT modules to CPU, move encoder to GPU + - Before DiT runs: move encoders to CPU, move active DiT to GPU + + This ensures only one large model group is on GPU at a time. 
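+
+    Illustrative sequence (hypothetical model sizes): with a ~20 GB DiT and a
+    ~10 GB text encoder sharing a 24 GB GPU, the encoder's forward pre-hook first
+    moves the DiT to CPU before encoding, and the DiT's pre-hook later moves the
+    encoders back to CPU before denoising.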
+ """ + + def __init__( + self, + dits: list[nn.Module], + encoders: list[nn.Module], + device: torch.device, + pin_memory: bool = True, + ): + assert all(isinstance(m, nn.Module) for m in dits), "All dits must be nn.Module" + assert all(isinstance(m, nn.Module) for m in encoders), "All encoders must be nn.Module" + self.dits = dits + self.encoders = encoders + self.device = device + self.pin_memory = pin_memory + self._handles: list = [] + + def _to_cpu(self, module: nn.Module) -> None: + """Move module to CPU with optional memory pinning.""" + # Skip if already on CPU + try: + param = next(module.parameters()) + if param.device.type == "cpu": + return + except StopIteration: + return + + previous_device = param.device + module.to("cpu", non_blocking=True) + + # Release allocator blocks when tensors leave the GPU. + if previous_device.type != "cpu": + torch.cuda.empty_cache() + + if self.pin_memory: + for p in module.parameters(): + if p.data.device.type == "cpu" and not p.data.is_pinned(): + p.data = p.data.pin_memory() + + def _to_gpu(self, module: nn.Module) -> None: + """Move module to GPU.""" + # Skip if already on target device + try: + if next(module.parameters()).device == self.device: + return + except StopIteration: + return + + module.to(self.device, non_blocking=True) + + def _dit_pre_hook(self, module: nn.Module, args: tuple) -> None: + """Before DiT forward: offload encoders, load DiT.""" + for enc in self.encoders: + self._to_cpu(enc) + self._to_gpu(module) + + current_omni_platform.synchronize() + + logger.debug("Swapped: encoders -> CPU, DiT -> GPU") + + def _encoder_pre_hook(self, module: nn.Module, args: tuple) -> None: + """Before encoder forward: offload DiT, load encoder.""" + for dit_mod in self.dits: + self._to_cpu(dit_mod) + self._to_gpu(module) + + current_omni_platform.synchronize() + + logger.debug("Swapped: DiT -> CPU, encoder -> GPU") + + def register(self) -> None: + """Register forward pre-hooks on DiT and encoders.""" + # Hook on each DiT-like module + for dit_mod in self.dits: + h = dit_mod.register_forward_pre_hook(self._dit_pre_hook) + self._handles.append(h) + logger.debug("Registered offload hook for %s", dit_mod.__class__.__name__) + + # Hook on each encoder + for enc in self.encoders: + h = enc.register_forward_pre_hook(self._encoder_pre_hook) + self._handles.append(h) + logger.debug("Registered offload hook for %s", enc.__class__.__name__) + + def remove(self) -> None: + """Remove all hooks.""" + for h in self._handles: + h.remove() + self._handles = [] + + +class LayerwiseOffloader: + """Layer-wise CPU offloading for transformer blocks. + + Keeps only a sliding window of layers (blocks), by default a single layer, on GPU, + prefetching the next block while the current block computes to approach compute - memcpy overlap. + Unused blocks are freed on GPU. 
+ + Based on implementations from: + https://github.com/sgl-project/sglang/blob/v0.5.8/python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py + """ + + def __init__( + self, + blocks: list[nn.Module], + device: torch.device, + pin_memory: bool = True, + num_gpu_layers: int = 1, + ): + assert all(isinstance(m, nn.Module) for m in blocks), "All transformer blocks must be torch.nn.Module" + assert current_omni_platform.is_cuda(), "Layerwise offloading is only supported on cuda devices for now" + + self.blocks = blocks + self.device = device + self.pin_memory = pin_memory + self.num_gpu_layers = num_gpu_layers + self.num_blocks = len(self.blocks) + if self.num_blocks == 0: + raise ValueError("LayerwiseOffloader requires at least one block, but found 0.") + if not (1 <= self.num_gpu_layers <= self.num_blocks): + raise ValueError(f"Invalid num_gpu_layers {self.num_gpu_layers} with {self.num_blocks} blocks") + + self._pre_hook_handles: list = [] + self._post_hook_handles: list = [] + + self._copy_stream = torch.cuda.Stream() + + # Per-layer synchronization primitive: set after H2D copy completes. + self._prefetch_done: list[torch.cuda.Event | None] = [None] * self.num_blocks + + # Simple state to avoid redundant work. + self._resident: list[bool] = [False] * self.num_blocks + + # Pre-allocate gpu tensors + # layer-id -> {dtype -> flattened aggregated cpu tensor} + self.layer_cpu_weights: list[dict[torch.dtype, torch.Tensor]] = [] + self.layer_metadata: list[dict[torch.dtype, list[dict[str, Any]]]] = [] + + self.block_parameters: dict[int, dict[str, nn.Parameter]] = {} + self.block_buffers: dict[int, dict[str, torch.Tensor]] = {} + for layer_idx, block in enumerate(self.blocks): + self.block_parameters[layer_idx] = dict(block.named_parameters()) + self.block_buffers[layer_idx] = dict(block.named_buffers()) + + dtype_cpu_flattened_weights, dtype_metadata = self._to_cpu( + self.block_parameters[layer_idx], self.block_buffers[layer_idx] + ) + self.layer_cpu_weights.append(dtype_cpu_flattened_weights) + self.layer_metadata.append(dtype_metadata) + + if self.num_blocks != len(self.layer_cpu_weights): + logger.error( + f"Inconsistent block layers happened: # of blocks: {self.num_blocks}; " + f"# of layer cpu weights: {len(self.layer_cpu_weights)}" + ) + + # Register pre and post forward hooks on each of the blocks + self.register_block_hooks() + + # Pre-fetch the first layer + # For subsequent requests, the first layer/block will be pre-fetched + # during the last layer compute of the previous request. + self.prefetch_layer(0, non_blocking=False) + + def _to_cpu( + self, params: dict[str, nn.Parameter], bufs: dict[str, torch.Tensor] + ) -> tuple[dict[torch.dtype, torch.Tensor], dict[torch.dtype, list[dict[str, Any]]]]: + """Move block parameters and buffers to CPU, flattening by dtype. + + Consolidates parameters and buffers into contiguous CPU tensors grouped by dtype + for GPU transfers. Replaces original tensors with empty placeholders. 
+ + Returns: + Tuple of + flattened CPU tensors by dtype, + metadata for reconstruction by dtype + """ + dtype_grouped_weights: dict[torch.dtype, dict[str, torch.Tensor]] = {} + dtype_cpu_flattened_weights: dict[torch.dtype, torch.Tensor] = {} + # order does matter + dtype_metadata: dict[torch.dtype, list[dict[str, Any]]] = {} + + for name, param_or_buf in chain(params.items(), bufs.items()): + dtype = param_or_buf.dtype + if dtype not in dtype_grouped_weights: + dtype_grouped_weights[dtype] = {} + dtype_grouped_weights[dtype][name] = param_or_buf + + for dtype, name2weights in dtype_grouped_weights.items(): + # total # of parameters + buffers + total_numel = sum(t.numel() for _, t in name2weights.items()) + cpu_tensor = torch.empty(total_numel, dtype=dtype, device="cpu", pin_memory=self.pin_memory) + + current_offset = 0 + for name, param_or_buf in name2weights.items(): + numel = param_or_buf.numel() + cpu_tensor[current_offset : current_offset + numel].copy_(param_or_buf.flatten()) + if dtype not in dtype_metadata: + dtype_metadata[dtype] = [] + dtype_metadata[dtype].append( + { + "name": name, + "offset": current_offset, + "numel": numel, + "shape": param_or_buf.shape, + } + ) + + param_or_buf.data = torch.empty((), device=self.device, dtype=dtype) + current_offset += numel + + dtype_cpu_flattened_weights[dtype] = cpu_tensor + + return dtype_cpu_flattened_weights, dtype_metadata + + def register_block_hooks(self) -> None: + """Register forward hooks on blocks for prefetching and offloading.""" + + def _pre_hook(module: nn.Module, args: tuple, *, layer_idx: int) -> None: + # For the last block / layer, prefetch layer 0 (the first layer) + next_id = (layer_idx + 1) % self.num_blocks + self.prefetch_layer(next_id, non_blocking=True) + + def _post_hook(module: nn.Module, args: tuple, output: tuple, *, layer_idx: int) -> None: + self.offload_layer(layer_idx) + self._resident[layer_idx] = False + self._prefetch_done[layer_idx] = None + + for i, layer in enumerate(self.blocks): + pre_hook_fn = partial(_pre_hook, layer_idx=i) + handle = layer.register_forward_pre_hook(pre_hook_fn) + self._pre_hook_handles.append(handle) + + post_hook_fn = partial(_post_hook, layer_idx=i) + handle = layer.register_forward_hook(post_hook_fn) + self._post_hook_handles.append(handle) + + @torch.compiler.disable + def prefetch_layer(self, layer_idx: int, non_blocking: bool = True) -> None: + """Copy layer weights from CPU -> GPU. + + Pre-fetch target layer in an asynchronous way with compute - memory copy overlap, + with non_blocking set to True. 
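+ Completion of each host-to-device copy is recorded as a CUDA event in _prefetch_done; offload_layer waits on that event before freeing the layer's GPU tensors.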
+ """ + if layer_idx >= self.num_blocks or layer_idx < 0: + logger.warning(f"Invalid layer id specified: {layer_idx}") + return + + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + layers_to_fetch = [(layer_idx + i) % self.num_blocks for i in range(self.num_gpu_layers)] + + for idx in layers_to_fetch: + if self._resident[idx]: + continue + + layer_params = self.block_parameters[idx] + layer_bufs = self.block_buffers[idx] + + evt = torch.cuda.Event() + gpu_weights: dict[torch.dtype, torch.Tensor] = {} + + with torch.cuda.stream(self._copy_stream): + for dtype, cpu_weight in self.layer_cpu_weights[idx].items(): + gpu_weight = torch.empty(cpu_weight.shape, dtype=dtype, device=self.device) + gpu_weight.copy_(cpu_weight, non_blocking=non_blocking) + gpu_weights[dtype] = gpu_weight + + evt.record(self._copy_stream) + + for dtype in self.layer_metadata[idx]: + ordered_metadata: list[dict[str, Any]] = self.layer_metadata[idx][dtype] + + gpu_weight = gpu_weights[dtype] + + for metadata in ordered_metadata: + target_name = metadata["name"] + target_param_or_buf = ( + layer_params[target_name] if target_name in layer_params else layer_bufs[target_name] + ) + + target_param_or_buf.data = gpu_weight[ + metadata["offset"] : metadata["offset"] + metadata["numel"] + ].view(metadata["shape"]) + + self._prefetch_done[idx] = evt + self._resident[idx] = True + + @torch.compiler.disable + def offload_layer(self, layer_idx: int) -> None: + """Free GPU memory for layer by replacing tensors with empty placeholders.""" + if layer_idx >= self.num_blocks or layer_idx < 0: + logger.warning(f"Invalid layer id specified: {layer_idx}") + return + if not self._resident[layer_idx]: + logger.warning(f"{layer_idx} is not residing on GPU") + return + + evt = self._prefetch_done[layer_idx] + if evt is not None: + torch.cuda.current_stream().wait_event(evt) + + # free GPU residency + for _, param in self.block_parameters[layer_idx].items(): + param.data = torch.empty((), device=self.device, dtype=param.dtype) + for _, buf in self.block_buffers[layer_idx].items(): + buf.data = torch.empty((), device=self.device, dtype=buf.dtype) + + def remove_all_hooks(self) -> None: + """Remove all hooks.""" + for h in self._pre_hook_handles: + h.remove() + for h in self._post_hook_handles: + h.remove() + self._pre_hook_handles.clear() + self._post_hook_handles.clear() + + @staticmethod + def get_blocks_attr_name(model: nn.Module) -> str | None: + """Retrieve blocks attribute name from provided DiT model""" + return getattr(model.__class__, "_layerwise_offload_blocks_attr", None) + + @staticmethod + def get_blocks_from_dit(model: nn.Module) -> list[nn.Module]: + """ + Retrieve a list of blocks from provided DiT model. Blocks attribute name + are found by `_layerwise_offload_blocks_attr` set to DiT models. 
For example, + + ``` + class WanTransformer3DModel(nn.Module): + _layerwise_offload_blocks_attr = "blocks" + ``` + """ + blocks_attr_name = LayerwiseOffloader.get_blocks_attr_name(model) + if blocks_attr_name is None: + logger.warning( + f"No _layerwise_offload_blocks_attr defined for {model.__class__.__name__}, " + "skipping layerwise offloading" + ) + return [] + + _blocks = getattr(model, blocks_attr_name, None) + if _blocks is None: + logger.warning( + f"Blocks (layers) '{blocks_attr_name}' not found on {model.__class__.__name__}, " + "skipping layerwise offloading" + ) + return [] + + return list(_blocks) + + +def apply_offload_hooks( + model: nn.Module, + od_config: OmniDiffusionConfig, + *, + device: torch.device | None = None, +) -> None: + """Apply mutual-exclusion offload hooks based on config. + + When enable_cpu_offload is enabled, DiT and encoders swap GPU access: + - Encoders (text_encoder, text_encoder_2, text_encoder_3, image_encoder) + run on GPU while DiT is on CPU + - DiT runs on GPU while encoders are on CPU + + Args: + model: Diffusion pipeline model + od_config: OmniDiffusionConfig with offload settings + """ + enable_cpu_offload = getattr(od_config, "enable_cpu_offload", False) + enable_layerwise_offload = getattr(od_config, "enable_layerwise_offload", False) + pin_cpu_memory = getattr(od_config, "pin_cpu_memory", True) + + if not enable_cpu_offload and not enable_layerwise_offload: + return + if enable_cpu_offload and enable_layerwise_offload: + # NOTE: Model-wise and layerwise cpu offloading are not supported together at this moment, + # consider layerwise offloading has higher priority than model-wise offloading + enable_cpu_offload = False + logger.info( + "Model-wise and layer-wise CPU offloading are not supported together at this moment. " + "Automatically disabled model-wise offloading." + ) + # For now, model-wise and layer-wise (block-wise) offloading + # are functioning as expected when cuda device is available + if not current_omni_platform.is_cuda() or current_omni_platform.get_device_count() < 1: + logger.info("CPU Offloading requires cuda devices available. Skipping for now...") + return + + # Find DiT/transformer modules + dit_modules: list[nn.Module] = [] + dit_names: list[str] = [] + candidate_attrs = ["transformer", "transformer_2", "dit"] + for attr in candidate_attrs: + if not hasattr(model, attr): + continue + module_obj = getattr(model, attr) + if module_obj is None: + continue + + assert isinstance(module_obj, nn.Module), f"Expected {attr} to be nn.Module, got {type(module_obj)!r}" + + if module_obj in dit_modules: + continue + + dit_modules.append(module_obj) + dit_names.append(attr) + + if not dit_modules: + logger.warning("enable_cpu_offload enabled but no transformer/dit/unet found") + return + if device is None: + try: + device = next(dit_modules[0].parameters()).device + except StopIteration: + try: + device = current_omni_platform.get_torch_device() + except (NotImplementedError, AttributeError): + logger.error("Fail to get device of pipeline. 
Skipping applying offloading hooks") + return + + # Collect all encoders + encoders: list[nn.Module] = [] + encoder_names: list[str] = [] + for attr in ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder"]: + if hasattr(model, attr) and getattr(model, attr) is not None: + encoders.append(getattr(model, attr)) + encoder_names.append(attr) + if not encoders and enable_cpu_offload: + logger.warning("enable_cpu_offload enabled but no encoders found") + return + for enc in encoders: + enc.to(device) + + # Collect VAE + for name in ["vae"]: + module = getattr(model, name, None) + if module is None: + continue + try: + module.to(device, non_blocking=True) + except Exception as exc: + logger.debug("Failed to move %s to GPU: %s", name, exc) + + if enable_cpu_offload: + # Initial state: keep DiT modules on CPU (encoders typically run first) + for dit_mod in dit_modules: + dit_mod.to("cpu") + + torch.cuda.empty_cache() + + if pin_cpu_memory: + for dit_mod in dit_modules: + for p in dit_mod.parameters(): + if p.data.device.type == "cpu" and not p.data.is_pinned(): + p.data = p.data.pin_memory() + + # Register sequential offload hooks + SequentialOffloader(dit_modules, encoders, device, pin_cpu_memory).register() + logger.info( + "CPU offload enabled: %s <-> %s (mutual exclusion)", + ", ".join(dit_names), + ", ".join(encoder_names), + ) + elif enable_layerwise_offload: + logger.info(f"Applying offloading hooks on {dit_names}") + + for i, dit_module in enumerate(dit_modules): + logger.info(f"Applying hook on {dit_names[i]} ({dit_module.__class__.__name__})") + blocks_attr_name = LayerwiseOffloader.get_blocks_attr_name(dit_module) + blocks = LayerwiseOffloader.get_blocks_from_dit(dit_module) + + if not blocks_attr_name or not blocks: + logger.warning( + "Target layers (blocks) are not found. 
" + f"Skipping offloading on {dit_names[i]} ({dit_module.__class__.__name__})" + ) + continue + + # move modules other than blocks to gpu and keep them on gpu + for name, m in dit_module.named_children(): + # Skip the blocks module (layers to be offloaded) + if name == blocks_attr_name: + logger.debug(f"Skipped module {name}") + continue + + m.to(device) + logger.debug(f"Moved {name} to device {device}") + + # set to the module (transformer) + offloader = LayerwiseOffloader(blocks, device, pin_cpu_memory, od_config.layerwise_num_gpu_layers) + setattr(dit_module, "_layerwise_offloader", offloader) + + logger.info( + f"Layerwise offloading enabled on {len(blocks)} layers (blocks), " + f"with {od_config.layerwise_num_gpu_layers} kept on device" + ) diff --git a/vllm_omni/diffusion/profiler/__init__.py b/vllm_omni/diffusion/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df505cbaf67534938eb85b4a1f2d13661caa9eb4 --- /dev/null +++ b/vllm_omni/diffusion/profiler/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .torch_profiler import TorchProfiler + +# Default profiler – can be changed later via config +CurrentProfiler = TorchProfiler + +__all__ = ["CurrentProfiler", "TorchProfiler"] diff --git a/vllm_omni/diffusion/profiler/base.py b/vllm_omni/diffusion/profiler/base.py new file mode 100644 index 0000000000000000000000000000000000000000..640e406da9594838e43c46f12a22ea3758b3d599 --- /dev/null +++ b/vllm_omni/diffusion/profiler/base.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ProfilerBase(ABC): + """ + Abstract base class for all diffusion profilers. + Defines the common interface used by GPUWorker and DiffusionEngine. + """ + + @abstractmethod + def start(self, trace_path_template: str) -> str: + """ + Start profiling. + + Args: + trace_path_template: Base path (without rank or extension). + e.g. "/tmp/profiles/sdxl_run" + + Returns: + Full path of the trace file this rank will write. + """ + pass + + @abstractmethod + def stop(self) -> str | None: + """ + Stop profiling and finalize/output the trace. + + Returns: + Path to the saved trace file, or None if not active. + """ + pass + + @abstractmethod + def get_step_context(self): + """ + Returns a context manager that advances one profiling step. + Should be a no-op (nullcontext) when profiler is not active. 
+ """ + pass + + @abstractmethod + def is_active(self) -> bool: + """Return True if profiling is currently running.""" + pass + + @classmethod + def _get_rank(cls) -> int: + import os + + return int(os.getenv("RANK", "0")) diff --git a/vllm_omni/diffusion/profiler/torch_profiler.py b/vllm_omni/diffusion/profiler/torch_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..37c457100716dc571c1f9d0d5d7944ce1f21f93c --- /dev/null +++ b/vllm_omni/diffusion/profiler/torch_profiler.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import subprocess +from contextlib import nullcontext + +import torch +from torch.profiler import ProfilerActivity, profile +from vllm.logger import init_logger + +from .base import ProfilerBase + +logger = init_logger(__name__) + + +class TorchProfiler(ProfilerBase): + """ + Torch-based profiler configured for End-to-End continuous recording. + Uses 'on_trace_ready' to handle Trace export. + Compression is offloaded to a background subprocess to avoid blocking the worker loop. + """ + + _profiler: profile | None = None + _trace_template: str = "" + + @classmethod + def start(cls, trace_path_template: str) -> str: + """ + Start the profiler with the given trace path template. + """ + # 1. Cleanup any existing profiler + if cls._profiler is not None: + logger.warning("[Rank %s] Stopping existing Torch profiler", cls._get_rank()) + cls._profiler.stop() + cls._profiler = None + + rank = cls._get_rank() + + # 2. Make path absolute + trace_path_template = os.path.abspath(trace_path_template) + cls._trace_template = trace_path_template + + # Expected paths + json_file = f"{trace_path_template}_rank{rank}.json" + + os.makedirs(os.path.dirname(json_file), exist_ok=True) + + logger.info(f"[Rank {rank}] Starting End-to-End Torch profiler") + + # 3. Define the on_trace_ready handler + def trace_handler(p): + nonlocal json_file + + # A. Export JSON Trace + try: + p.export_chrome_trace(json_file) + logger.info(f"[Rank {rank}] Trace exported to {json_file}") + + try: + subprocess.Popen(["gzip", "-f", json_file]) + logger.info(f"[Rank {rank}] Triggered background compression for {json_file}") + # Update variable to point to the eventual file + json_file = f"{json_file}.gz" + except Exception as compress_err: + logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}") + + except Exception as e: + logger.warning(f"[Rank {rank}] Failed to export trace: {e}") + + # 4. Initialize profiler with long active period + cls._profiler = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + wait=0, + warmup=0, + active=100000, # long capture window + ), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + ) + + # 5. Start profiling + cls._profiler.start() + + # Return the expected final path + return f"{trace_path_template}_rank{rank}.json.gz" + + @classmethod + def stop(cls) -> dict | None: + if cls._profiler is None: + return None + + rank = cls._get_rank() + + # Determine expected paths + base_path = f"{cls._trace_template}_rank{rank}" + gz_path = f"{base_path}.json.gz" + + try: + # This triggers trace_handler synchronously + # Since we removed table generation and backgrounded compression, this returns fast. 
+ cls._profiler.stop() + except Exception as e: + logger.warning(f"[Rank {rank}] Profiler stop failed: {e}") + + cls._profiler = None + + # We return the .gz path assuming background compression will succeed. + return {"trace": gz_path, "table": None} + + @classmethod + def step(cls): + if cls._profiler is not None: + cls._profiler.step() + + @classmethod + def is_active(cls) -> bool: + return cls._profiler is not None + + @classmethod + def get_step_context(cls): + return nullcontext() diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0243524793fdde2ea20fa61a15d992bd77fb489a --- /dev/null +++ b/vllm_omni/diffusion/registry.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib + +import torch.nn as nn +from vllm.logger import init_logger +from vllm.model_executor.models.registry import _LazyRegisteredModel, _ModelRegistry + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig, get_sp_plan_from_model +from vllm_omni.diffusion.hooks.sequence_parallel import apply_sequence_parallel + +logger = init_logger(__name__) + +_DIFFUSION_MODELS = { + # arch:(mod_folder, mod_relname, cls_name) + "QwenImagePipeline": ( + "qwen_image", + "pipeline_qwen_image", + "QwenImagePipeline", + ), + "QwenImageEditPipeline": ( + "qwen_image", + "pipeline_qwen_image_edit", + "QwenImageEditPipeline", + ), + "QwenImageEditPlusPipeline": ( + "qwen_image", + "pipeline_qwen_image_edit_plus", + "QwenImageEditPlusPipeline", + ), + "QwenImageLayeredPipeline": ( + "qwen_image", + "pipeline_qwen_image_layered", + "QwenImageLayeredPipeline", + ), + "GlmImagePipeline": ( + "glm_image", + "pipeline_glm_image", + "GlmImagePipeline", + ), + "ZImagePipeline": ( + "z_image", + "pipeline_z_image", + "ZImagePipeline", + ), + "OvisImagePipeline": ( + "ovis_image", + "pipeline_ovis_image", + "OvisImagePipeline", + ), + "WanPipeline": ( + "wan2_2", + "pipeline_wan2_2", + "Wan22Pipeline", + ), + "StableAudioPipeline": ( + "stable_audio", + "pipeline_stable_audio", + "StableAudioPipeline", + ), + "WanImageToVideoPipeline": ( + "wan2_2", + "pipeline_wan2_2_i2v", + "Wan22I2VPipeline", + ), + "LongCatImagePipeline": ( + "longcat_image", + "pipeline_longcat_image", + "LongCatImagePipeline", + ), + "BagelPipeline": ( + "bagel", + "pipeline_bagel", + "BagelPipeline", + ), + "LongCatImageEditPipeline": ( + "longcat_image", + "pipeline_longcat_image_edit", + "LongCatImageEditPipeline", + ), + "StableDiffusion3Pipeline": ( + "sd3", + "pipeline_sd3", + "StableDiffusion3Pipeline", + ), + "Flux2KleinPipeline": ( + "flux2_klein", + "pipeline_flux2_klein", + "Flux2KleinPipeline", + ), + "FluxPipeline": ( + "flux", + "pipeline_flux", + "FluxPipeline", + ), +} + + +DiffusionModelRegistry = _ModelRegistry( + { + model_arch: _LazyRegisteredModel( + module_name=f"vllm_omni.diffusion.models.{mod_folder}.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_folder, mod_relname, cls_name) in _DIFFUSION_MODELS.items() + } +) + + +def initialize_model( + od_config: OmniDiffusionConfig, +) -> nn.Module: + """Initialize a diffusion model from the registry. + + This function: + 1. Loads the model class from the registry + 2. Instantiates the model with the config + 3. Configures VAE optimization settings + 4. 
Applies sequence parallelism if enabled (similar to diffusers' enable_parallelism) + + Args: + od_config: The OmniDiffusion configuration. + + Returns: + The initialized pipeline model. + + Raises: + ValueError: If the model class is not found in the registry. + """ + model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name) + if model_class is not None: + model = model_class(od_config=od_config) + # Configure VAE memory optimization settings from config + if hasattr(model.vae, "use_slicing"): + model.vae.use_slicing = od_config.vae_use_slicing + if hasattr(model.vae, "use_tiling"): + model.vae.use_tiling = od_config.vae_use_tiling + + # Apply sequence parallelism if enabled + # This follows diffusers' pattern where enable_parallelism() is called + # at model loading time, not inside individual model files + _apply_sequence_parallel_if_enabled(model, od_config) + + return model + else: + raise ValueError(f"Model class {od_config.model_class_name} not found in diffusion model registry.") + + +def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -> None: + """Apply sequence parallelism hooks if SP is enabled. + + This is the centralized location for enabling SP, similar to diffusers' + ModelMixin.enable_parallelism() method. It applies _sp_plan hooks to + transformer models that define them. + + Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. + We use _sp_plan instead of diffusers' _cp_plan. + + Args: + model: The pipeline model (e.g., ZImagePipeline). + od_config: The OmniDiffusion configuration. + """ + + try: + sp_size = od_config.parallel_config.sequence_parallel_size + if sp_size <= 1: + return + + # Find transformer model(s) in the pipeline that have _sp_plan + # Include transformer_2 for two-stage models (e.g., Wan MoE) + transformer_attrs = ["transformer", "transformer_2", "dit", "unet"] + applied_count = 0 + + for attr in transformer_attrs: + if not hasattr(model, attr): + continue + + transformer = getattr(model, attr) + if transformer is None: + continue + + plan = get_sp_plan_from_model(transformer) + if plan is None: + continue + + # Create SP config + sp_config = SequenceParallelConfig( + ulysses_degree=od_config.parallel_config.ulysses_degree, + ring_degree=od_config.parallel_config.ring_degree, + ) + + # Apply hooks according to the plan + mode = ( + "hybrid" + if sp_config.ulysses_degree > 1 and sp_config.ring_degree > 1 + else ("ulysses" if sp_config.ulysses_degree > 1 else "ring") + ) + logger.info( + f"Applying sequence parallelism to {transformer.__class__.__name__} ({attr}) " + f"(sp_size={sp_size}, mode={mode}, ulysses={sp_config.ulysses_degree}, ring={sp_config.ring_degree})" + ) + apply_sequence_parallel(transformer, sp_config, plan) + applied_count += 1 + + if applied_count == 0: + logger.warning( + f"Sequence parallelism is enabled (sp_size={sp_size}) but no transformer with _sp_plan found. " + "SP hooks not applied. Consider adding _sp_plan to your transformer model." + ) + + except Exception as e: + logger.warning(f"Failed to apply sequence parallelism: {e}. 
Continuing without SP hooks.") + + +_DIFFUSION_POST_PROCESS_FUNCS = { + # arch: post_process_func + # `post_process_func` function must be placed in {mod_folder}/{mod_relname}.py, + # where mod_folder and mod_relname are defined and mapped using `_DIFFUSION_MODELS` via the `arch` key + "QwenImagePipeline": "get_qwen_image_post_process_func", + "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func", + "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_post_process_func", + "GlmImagePipeline": "get_glm_image_post_process_func", + "ZImagePipeline": "get_post_process_func", + "OvisImagePipeline": "get_ovis_image_post_process_func", + "WanPipeline": "get_wan22_post_process_func", + "StableAudioPipeline": "get_stable_audio_post_process_func", + "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", + "LongCatImagePipeline": "get_longcat_image_post_process_func", + "BagelPipeline": "get_bagel_post_process_func", + "LongCatImageEditPipeline": "get_longcat_image_post_process_func", + "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", + "Flux2KleinPipeline": "get_flux2_klein_post_process_func", + "FluxPipeline": "get_flux_post_process_func", +} + +_DIFFUSION_PRE_PROCESS_FUNCS = { + # arch: pre_process_func + # `pre_process_func` function must be placed in {mod_folder}/{mod_relname}.py, + # where mod_folder and mod_relname are defined and mapped using `_DIFFUSION_MODELS` via the `arch` key + "GlmImagePipeline": "get_glm_image_pre_process_func", + "QwenImageEditPipeline": "get_qwen_image_edit_pre_process_func", + "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_pre_process_func", + "LongCatImageEditPipeline": "get_longcat_image_edit_pre_process_func", + "QwenImageLayeredPipeline": "get_qwen_image_layered_pre_process_func", + "WanPipeline": "get_wan22_pre_process_func", + "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", +} + + +def _load_process_func(od_config: OmniDiffusionConfig, func_name: str): + """Load and return a process function from the appropriate module.""" + mod_folder, mod_relname, _ = _DIFFUSION_MODELS[od_config.model_class_name] + module_name = f"vllm_omni.diffusion.models.{mod_folder}.{mod_relname}" + module = importlib.import_module(module_name) + func = getattr(module, func_name) + return func(od_config) + + +def get_diffusion_post_process_func(od_config: OmniDiffusionConfig): + if od_config.model_class_name not in _DIFFUSION_POST_PROCESS_FUNCS: + return None + func_name = _DIFFUSION_POST_PROCESS_FUNCS[od_config.model_class_name] + return _load_process_func(od_config, func_name) + + +def get_diffusion_pre_process_func(od_config: OmniDiffusionConfig): + if od_config.model_class_name not in _DIFFUSION_PRE_PROCESS_FUNCS: + return None # Return None if no pre-processing function is registered (for backward compatibility) + func_name = _DIFFUSION_PRE_PROCESS_FUNCS[od_config.model_class_name] + return _load_process_func(od_config, func_name) diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py new file mode 100644 index 0000000000000000000000000000000000000000..a6005290cdc49060b51d052ffb79148a2797a7ea --- /dev/null +++ b/vllm_omni/diffusion/request.py @@ -0,0 +1,44 @@ +# adapted from sglang and fastvideo +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType + + +@dataclass +class OmniDiffusionRequest: + """ + Complete state passed through 
the pipeline execution. + + This dataclass contains the prompts and sampling parameters for the diffusion pipeline + execution. It also contains a request_id for other components to trace this request and its outputs. + """ + + # TODO(will): double check that args are separate from server_args + # properly. Also maybe think about providing an abstraction for pipeline + # specific arguments. + # data_type: DataType + + prompts: list[OmniPromptType] # Actually supporting str-based prompts + sampling_params: OmniDiffusionSamplingParams + + request_ids: list[str] = field(default_factory=list) + + def __post_init__(self): + """Initialize dependent fields after dataclass initialization.""" + # Set do_classifier_free_guidance based on guidance scale and negative prompt + if self.sampling_params.guidance_scale > 1.0 and any( + (not isinstance(p, str) and p.get("negative_prompt")) for p in self.prompts + ): + self.sampling_params.do_classifier_free_guidance = True + if self.sampling_params.guidance_scale_2 is None: + self.sampling_params.guidance_scale_2 = self.sampling_params.guidance_scale + + # The dataclass default value is 0 (false-like), used to detect whether user explicitly provides this value + # After this check is done, reset this value to old default 1 + if self.sampling_params.guidance_scale: + self.sampling_params.guidance_scale_provided = True + else: + self.sampling_params.guidance_scale = 1.0 diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f104d052266690e6f5a14c21ae6994c04fc978ea --- /dev/null +++ b/vllm_omni/diffusion/scheduler.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import zmq +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +class Scheduler: + def initialize(self, od_config: OmniDiffusionConfig): + existing_context = getattr(self, "context", None) + if existing_context is not None and not existing_context.closed: + logger.warning("SyncSchedulerClient is already initialized. 
Re-initializing.") + self.close() + + self.num_workers = od_config.num_gpus + self.od_config = od_config + self.context = zmq.Context() # Standard synchronous context + + # Initialize single MessageQueue for all message types (generation & RPC) + # Assuming all readers are local for now as per current launch_engine implementation + self.mq = MessageQueue( + n_reader=self.num_workers, + n_local_reader=self.num_workers, + local_reader_ranks=list(range(self.num_workers)), + ) + + self.result_mq = None + + def initialize_result_queue(self, handle): + # Initialize MessageQueue for receiving results + # We act as rank 0 reader for this queue + self.result_mq = MessageQueue.create_from_handle(handle, rank=0) + logger.info("SyncScheduler initialized result MessageQueue") + + def get_broadcast_handle(self): + return self.mq.export_handle() + + def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: + """Sends a request to the scheduler and waits for the response.""" + try: + # Prepare RPC request for generation + rpc_request = { + "type": "rpc", + "method": "generate", + "args": (request,), + "kwargs": {}, + "output_rank": 0, + "exec_all_ranks": True, + } + + # Broadcast RPC request to all workers + self.mq.enqueue(rpc_request) + # Wait for result from Rank 0 (or whoever sends it) + + if self.result_mq is None: + raise RuntimeError("Result queue not initialized") + + output = self.result_mq.dequeue() + # {"status": "error", "error": str(e)} + if isinstance(output, dict) and output.get("status") == "error": + raise RuntimeError("worker error") + return output + except zmq.error.Again: + logger.error("Timeout waiting for response from scheduler.") + raise TimeoutError("Scheduler did not respond in time.") + + def close(self): + """Closes the socket and terminates the context.""" + if hasattr(self, "context"): + self.context.term() + self.context = None + self.mq = None + self.result_mq = None diff --git a/vllm_omni/diffusion/utils/__init__.py b/vllm_omni/diffusion/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/diffusion/utils/hf_utils.py b/vllm_omni/diffusion/utils/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc1807a188fdbdf46c6fd0a96ac14c9575e6c43 --- /dev/null +++ b/vllm_omni/diffusion/utils/hf_utils.py @@ -0,0 +1,79 @@ +import os +from functools import lru_cache + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_hf_file_to_dict + +logger = init_logger(__name__) + + +def load_diffusers_config(model_name) -> dict: + from diffusers.pipelines.pipeline_utils import DiffusionPipeline + + config = DiffusionPipeline.load_config(model_name) + return config + + +def _looks_like_bagel(model_name: str) -> bool: + """Best-effort detection for Bagel (non-diffusers) diffusion models.""" + try: + cfg = get_hf_file_to_dict("config.json", model_name) + except Exception: + return False + model_type = cfg.get("model_type") + if model_type == "bagel": + return True + architectures = cfg.get("architectures") or [] + return "BagelForConditionalGeneration" in architectures + + +@lru_cache +def is_diffusion_model(model_name: str) -> bool: + """Check if a model is a diffusion model. + + Uses multiple fallback strategies to detect diffusion models: + 1. Check local file system for model_index.json (fastest, no imports) + 2. Check using vllm's get_hf_file_to_dict utility + 3. 
Try the standard diffusers approach (may fail due to import issues) + """ + # Strategy 1: Check local file system first (fastest, avoids import issues) + if os.path.isdir(model_name): + model_index_path = os.path.join(model_name, "model_index.json") + if os.path.exists(model_index_path): + try: + import json + + with open(model_index_path) as f: + config_dict = json.load(f) + if config_dict.get("_class_name") and config_dict.get("_diffusers_version"): + logger.debug("Detected diffusion model via local model_index.json") + return True + except Exception as e: + logger.debug("Failed to read local model_index.json: %s", e) + + # Strategy 2: Check using vllm's utility (works for both local and remote models) + try: + # Try to get model_index.json using vllm's utility + config_dict = get_hf_file_to_dict("model_index.json", model_name) + # Verify it has the required fields for a diffusers model + if config_dict.get("_class_name") and config_dict.get("_diffusers_version"): + logger.debug("Detected diffusion model via model_index.json") + return True + except Exception as e: + logger.debug("Failed to check model_index.json via get_hf_file_to_dict: %s", e) + + # Strategy 3: Try the standard diffusers approach (may fail due to import issues) + # This is last because it requires importing diffusers/xformers/flash_attn + # which may have compatibility issues + try: + load_diffusers_config(model_name) + return True + except (ImportError, ModuleNotFoundError) as e: + logger.debug("Failed to import diffusers dependencies: %s", e) + logger.debug("This may be due to flash_attn/PyTorch version mismatch") + except Exception as e: + logger.debug("Failed to load diffusers config via DiffusionPipeline: %s", e) + + # Bagel is not a diffusers pipeline (no model_index.json), but is still a + # diffusion-style model in vllm-omni. Detect it via config.json. + return _looks_like_bagel(model_name) diff --git a/vllm_omni/diffusion/utils/network_utils.py b/vllm_omni/diffusion/utils/network_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c3085e8f9a9314862bf1624c546fee966b037c3f --- /dev/null +++ b/vllm_omni/diffusion/utils/network_utils.py @@ -0,0 +1,19 @@ +import socket + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except OSError: + return False + except OverflowError: + return False diff --git a/vllm_omni/diffusion/utils/tf_utils.py b/vllm_omni/diffusion/utils/tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44a78804452d25ce0a7ac0c6c81504e740f3d3e6 --- /dev/null +++ b/vllm_omni/diffusion/utils/tf_utils.py @@ -0,0 +1,54 @@ +import inspect +from typing import Any + +from vllm_omni.diffusion.data import TransformerConfig + + +def get_transformer_config_kwargs( + tf_model_config: TransformerConfig, model_class: type[Any] | None = None +) -> dict[str, Any]: + """ + This function extracts parameters from a TransformerConfig instance and filters out internal + diffusers metadata keys (those starting with '_') that should not be passed to model initialization. + Also filters out parameters that are not accepted by the model's __init__ method (e.g., pooled_projection_dim + for QwenImageTransformer2DModel). 
+ + This uses inspect.signature to dynamically detect accepted parameters, making it general for any model class. + Similar to how diffusers' @register_to_config decorator works. + + Args: + tf_model_config: TransformerConfig instance containing model parameters + model_class: Optional model class to inspect for accepted __init__ parameters. + If None, all non-internal parameters are returned (backward compatibility). + + Returns: + dict: Filtered dictionary of parameters suitable for transformer model initialization + """ + # Extract transformer config parameters, filtering out internal diffusers metadata + # TransformerConfig stores params in a 'params' dict, and we need to exclude + # internal keys like '_class_name' and '_diffusers_version' + tf_config_params = tf_model_config.to_dict() + + # Filter out internal diffusers metadata keys that start with '_' + filtered_params = {k: v for k, v in tf_config_params.items() if not k.startswith("_")} + + # If model_class is provided, use inspect.signature to get accepted parameters + if model_class is not None: + try: + # Get the signature of the model's __init__ method + sig = inspect.signature(model_class.__init__) + # Get all parameter names (excluding 'self' and special parameters) + accepted_params = { + name + for name, param in sig.parameters.items() + if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD # Exclude **kwargs + } + + # Filter to only include parameters that are in the model's signature + filtered_params = {k: v for k, v in filtered_params.items() if k in accepted_params} + except (TypeError, AttributeError): + # If inspection fails, fall back to returning all non-internal params + # This maintains backward compatibility + pass + + return filtered_params diff --git a/vllm_omni/diffusion/worker/__init__.py b/vllm_omni/diffusion/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8af0283857f54f0e978b7609c84ea1f9b14045a9 --- /dev/null +++ b/vllm_omni/diffusion/worker/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Worker classes for diffusion models.""" + +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.diffusion_worker import ( + DiffusionWorker, + WorkerProc, +) + +__all__ = [ + "DiffusionModelRunner", + "DiffusionWorker", + "WorkerProc", +] diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..38f438062650d222445a5ddb660f83fc4922b8ac --- /dev/null +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Diffusion Model Runner for vLLM-Omni. + +Handles model loading, compilation, caching, and execution of diffusion model +forward passes. This follows the AR pattern where the Runner handles all +model-related operations. 
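+ Infrastructure concerns (device setup, distributed initialization, sleep/wake memory management) stay in DiffusionWorker; see diffusion_worker.py.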
+""" + +from __future__ import annotations + +import time +from collections.abc import Iterable +from contextlib import nullcontext + +import torch +from torch.profiler import record_function +from vllm.config import LoadConfig +from vllm.logger import init_logger +from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes + +from vllm_omni.diffusion.cache.cache_dit_backend import cache_summary +from vllm_omni.diffusion.cache.selector import get_cache_backend +from vllm_omni.diffusion.compile import regionally_compile +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.offload import apply_offload_hooks +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +class DiffusionModelRunner: + """ + Model runner that handles model loading and execution for diffusion models. + + This class follows the AR pattern where the Runner handles all model-related + operations including loading, compilation, offloading, caching, and execution. + The Worker only handles infrastructure (device, distributed env). + """ + + def __init__( + self, + vllm_config, + od_config: OmniDiffusionConfig, + device: torch.device, + ): + """ + Initialize the diffusion model runner. + + Args: + vllm_config: vLLM configuration. + od_config: OmniDiffusion configuration. + device: The device to run on. + """ + self.vllm_config = vllm_config + self.od_config = od_config + self.device = device + self.pipeline = None + self.cache_backend = None + + # Initialize KV cache manager for connector management + self.kv_transfer_manager = OmniKVTransferManager.from_od_config(od_config) + + def load_model( + self, + memory_pool_context_fn: callable | None = None, + ) -> None: + """ + Load the diffusion model, apply compilation and offloading. + + Args: + memory_pool_context_fn: Optional function that returns a context manager + for memory pool allocation (used for sleep mode). 
+ """ + load_device = ( + "cpu" if self.od_config.enable_cpu_offload or self.od_config.enable_layerwise_offload else str(self.device) + ) + + def get_memory_context(): + if memory_pool_context_fn is not None: + return memory_pool_context_fn(tag="weights") + return nullcontext() + + # Load model within forward context + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + load_config = LoadConfig() + model_loader = DiffusersPipelineLoader(load_config) + time_before_load = time.perf_counter() + + with get_memory_context(): + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=load_device, + ) + time_after_load = time.perf_counter() + + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + m.consumed_memory / GiB_bytes, + time_after_load - time_before_load, + ) + logger.info("Model runner: Model loaded successfully.") + + # Apply CPU offloading + if self.od_config.enable_cpu_offload or self.od_config.enable_layerwise_offload: + apply_offload_hooks(self.pipeline, self.od_config, device=self.device) + + # Apply torch.compile if not in eager mode + if not self.od_config.enforce_eager: + if current_omni_platform.supports_torch_inductor(): + try: + self.pipeline.transformer = regionally_compile( + self.pipeline.transformer, + dynamic=True, + ) + logger.info("Model runner: Model compiled with torch.compile.") + except Exception as e: + logger.warning(f"Model runner: torch.compile failed with error: {e}. Using eager mode.") + else: + logger.warning( + "Model runner: Platform %s does not support torch inductor, skipping torch.compile.", + current_omni_platform.get_torch_device(), + ) + + # Setup cache backend + self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config) + + if self.cache_backend is not None: + self.cache_backend.enable(self.pipeline) + + logger.info("Model runner: Initialization complete.") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights into the pipeline.""" + return self.pipeline.load_weights(weights) + + @torch.inference_mode() + def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: + """ + Execute a forward pass for the given requests. + + Args: + req: A diffusion request containing a list of prompts to process. + + Returns: + DiffusionOutput with generated results. + """ + assert self.pipeline is not None, "Model not loaded. Call load_model() first." 
+ if len(req.prompts) == 0: + raise ValueError("Cannot execute model with empty request list") + + # The manager handles the check for need_recv_cache internally + self.kv_transfer_manager.receive_kv_cache(req, target_device=getattr(self.pipeline, "device", None)) + + if req.sampling_params.generator is None and req.sampling_params.seed is not None: + req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + + # Refresh cache context if needed + if ( + not getattr(req, "skip_cache_refresh", False) + and self.cache_backend is not None + and self.cache_backend.is_enabled() + ): + self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) + + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + with record_function("pipeline_forward"): + output = self.pipeline.forward(req) + + # NOTE: + if self.od_config.cache_backend == "cache_dit" and self.od_config.enable_cache_dit_summary: + cache_summary(self.pipeline, details=True) + + return output diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..0fad90a5b38d6e9c50967dcbd78a0dd5bb2a583a --- /dev/null +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -0,0 +1,401 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Diffusion Worker for vLLM-Omni. + +Handles GPU infrastructure initialization and delegates model operations +to DiffusionModelRunner. +""" + +import multiprocessing as mp +import os +from contextlib import AbstractContextManager, nullcontext + +import torch +import zmq +from vllm.config import VllmConfig +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.logger import init_logger +from vllm.utils.mem_utils import GiB_bytes + +from vllm_omni.diffusion.data import ( + DiffusionOutput, + OmniDiffusionConfig, +) +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager +from vllm_omni.diffusion.profiler import CurrentProfiler +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.lora.request import LoRARequest +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +class DiffusionWorker: + """ + A worker that manages GPU infrastructure and delegates to the model runner. + + This class handles infrastructure initialization only: + - Device setup (CUDA device selection) + - Distributed environment (NCCL, model parallel) + - Memory management (sleep/wake) + + All model-related operations (loading, compilation, execution) are + delegated to DiffusionModelRunner. 
+ """ + + def __init__( + self, + local_rank: int, + rank: int, + od_config: OmniDiffusionConfig, + ): + self.local_rank = local_rank + self.rank = rank + self.od_config = od_config + self.device: torch.device | None = None + self.vllm_config: VllmConfig | None = None + self.model_runner: DiffusionModelRunner | None = None + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + self.lora_manager: DiffusionLoRAManager | None = None + self.init_device() + + def init_device(self) -> None: + """Initialize the device and distributed environment.""" + world_size = self.od_config.num_gpus + rank = self.rank + + # Set environment variables for distributed initialization + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.od_config.master_port) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Setup device + self.device = current_omni_platform.get_torch_device(rank) + current_omni_platform.set_device(self.device) + + # Create vllm_config for parallel configuration + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size + vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size + self.vllm_config = vllm_config + + # Initialize distributed environment + with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config): + init_distributed_environment(world_size=world_size, rank=rank) + logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_size=parallel_config.data_parallel_size, + cfg_parallel_size=parallel_config.cfg_parallel_size, + sequence_parallel_size=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_size=parallel_config.tensor_parallel_size, + pipeline_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # Create model runner and load model + self.model_runner = DiffusionModelRunner( + vllm_config=self.vllm_config, + od_config=self.od_config, + device=self.device, + ) + self.model_runner.load_model( + memory_pool_context_fn=self._maybe_get_memory_pool_context, + ) + assert self.model_runner.pipeline is not None + self.lora_manager = DiffusionLoRAManager( + pipeline=self.model_runner.pipeline, + device=self.device, + dtype=self.od_config.dtype, + max_cached_adapters=self.od_config.max_cpu_loras, + lora_path=self.od_config.lora_path, + lora_scale=self.od_config.lora_scale, + ) + logger.info(f"Worker {self.rank}: Initialization complete.") + + def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput: + """Generate output for the given requests.""" + return self.execute_model(request, self.od_config) + + @classmethod + def start_profile(cls, trace_path_template: str) -> str: + """Start profiling for this GPU worker.""" + return CurrentProfiler.start(trace_path_template) + + @classmethod + def stop_profile(cls) -> dict | None: + """Stop profiling and return the result dictionary.""" + return CurrentProfiler.stop() + + def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: + """Execute a forward pass by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + try: + 
self.lora_manager.set_active_adapter(req.sampling_params.lora_request, req.sampling_params.lora_scale) + except Exception as exc: + if req.sampling_params.lora_request is not None: + raise + logger.warning("LoRA activation skipped: %s", exc) + return self.model_runner.execute_model(req) + + def load_weights(self, weights) -> set[str]: + """Load weights by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + return self.model_runner.load_weights(weights) + + def remove_lora(self, adapter_id: int) -> bool: + return self.lora_manager.remove_adapter(adapter_id) + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + return self.lora_manager.add_adapter(lora_request, lora_scale) + + def list_loras(self) -> list[int]: + return self.lora_manager.list_adapters() + + def pin_lora(self, adapter_id: int) -> bool: + return self.lora_manager.pin_adapter(adapter_id) + + def sleep(self, level: int = 1) -> bool: + """ + Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + """ + from vllm.device_allocator.cumem import CuMemAllocator + + free_bytes_before_sleep = current_omni_platform.get_free_memory() + + # Save the buffers before level 2 sleep + if level == 2 and self.model_runner is not None: + model = self.model_runner.pipeline + self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} + + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) + free_bytes_after_sleep = current_omni_platform.get_free_memory() + device_id = self.device.index if self.device.index is not None else 0 + total = current_omni_platform.get_device_total_memory(device_id) + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) + return True + + def wake_up(self, tags: list[str] | None = None) -> bool: + """ + Wake up the worker from sleep mode. See the sleep function + method for more details. + + Args: + tags: An optional list of tags to reallocate the worker memory + for specific memory allocations. Values must be in + `("weights")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + worker is used again. + """ + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers) and self.model_runner is not None: + model = self.model_runner.pipeline + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + return True + + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + """Get memory pool context for sleep mode support.""" + if self.od_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." 
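+ # Allocations made inside this "weights" pool are the ones later offloaded by sleep(level=1).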
+ return allocator.use_memory_pool(tag=tag) + else: + return nullcontext() + + def shutdown(self) -> None: + """Shutdown the worker and cleanup distributed environment.""" + destroy_distributed_env() + + +class WorkerProc: + """Wrapper that runs one Worker in a separate process.""" + + def __init__( + self, + od_config: OmniDiffusionConfig, + gpu_id: int, + broadcast_handle, + ): + self.od_config = od_config + + # Inter-process Communication + self.context = zmq.Context(io_threads=2) + + # Initialize MessageQueue reader from handle + self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) + + self.result_mq = None + self.result_mq_handle = None + + # Setup result sender (only for rank 0) + if gpu_id == 0: + self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0]) + self.result_mq_handle = self.result_mq.export_handle() + logger.info(f"Worker {gpu_id} created result MessageQueue") + + assert od_config.master_port is not None + self.worker = self._create_worker(gpu_id, od_config) + self.gpu_id = gpu_id + self._running = True + + def _create_worker(self, gpu_id: int, od_config: OmniDiffusionConfig) -> DiffusionWorker: + """Create a worker instance. Override in subclasses for different worker types.""" + return DiffusionWorker( + local_rank=gpu_id, + rank=gpu_id, + od_config=od_config, + ) + + def return_result(self, output: DiffusionOutput): + """Reply to client, only on rank 0.""" + if self.result_mq is not None: + self.result_mq.enqueue(output) + + def recv_message(self): + """Receive messages from broadcast queue.""" + return self.mq.dequeue(indefinite=True) + + def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: + """Execute an RPC request and indicate whether to reply.""" + method = rpc_request["method"] + args = rpc_request.get("args", ()) + kwargs = rpc_request.get("kwargs", {}) + output_rank = rpc_request.get("output_rank") + exec_all_ranks = rpc_request.get("exec_all_ranks", False) + + should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id + should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None + + if not should_execute: + return None, False + + try: + if isinstance(method, str): + func = getattr(self.worker, method) + result = func(*args, **kwargs) + else: + result = method(self.worker, *args, **kwargs) + return result, should_reply + except Exception as e: + logger.error(f"Error executing RPC: {e}", exc_info=True) + raise e + + def worker_busy_loop(self) -> None: + """Main busy loop for Multiprocessing Workers.""" + logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") + + while self._running: + msg = None + try: + msg = self.recv_message() + except Exception as e: + logger.error( + f"Error receiving message in worker loop: {e}", + exc_info=True, + ) + continue + + if msg is None or len(msg) == 0: + logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) + continue + + # Route message based on type + if isinstance(msg, dict) and msg.get("type") == "rpc": + try: + result, should_reply = self.execute_rpc(msg) + if should_reply: + self.return_result(result) + except Exception as e: + logger.error(f"Error processing RPC: {e}", exc_info=True) + if self.result_mq is not None: + self.return_result(DiffusionOutput(error=str(e))) + + elif isinstance(msg, dict) and msg.get("type") == "shutdown": + logger.info("Worker %s: Received shutdown message", self.gpu_id) + self._running = False + continue + + else: + # Handle 
generation request + try: + output = self.worker.execute_model(msg, self.od_config) + except Exception as e: + logger.error( + f"Error executing forward in event loop: {e}", + exc_info=True, + ) + output = DiffusionOutput(error=str(e)) + + try: + self.return_result(output) + except zmq.ZMQError as e: + logger.error(f"ZMQ error sending reply: {e}") + continue + + logger.info("event loop terminated.") + try: + self.worker.shutdown() + except Exception as exc: + logger.warning("Worker %s: Shutdown encountered an error: %s", self.gpu_id, exc) + self.context.term() + + @staticmethod + def worker_main( + rank: int, + od_config: OmniDiffusionConfig, + pipe_writer: mp.connection.Connection, + broadcast_handle, + ) -> None: + """Worker initialization and execution loops.""" + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + worker_proc = WorkerProc( + od_config, + gpu_id=rank, + broadcast_handle=broadcast_handle, + ) + logger.info(f"Worker {rank}: Scheduler loop started.") + pipe_writer.send( + { + "status": "ready", + "result_handle": worker_proc.result_mq_handle if rank == 0 else None, + } + ) + worker_proc.worker_busy_loop() + logger.info(f"Worker {rank}: Shutdown complete.") diff --git a/vllm_omni/distributed/__init__.py b/vllm_omni/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7634219f18288141422cfeb02ff29d73c872b1 --- /dev/null +++ b/vllm_omni/distributed/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .omni_connectors import ( + ConnectorSpec, + MooncakeConnector, + OmniConnectorBase, + OmniConnectorFactory, + OmniTransferConfig, + SharedMemoryConnector, + YuanrongConnector, + load_omni_transfer_config, +) + +__all__ = [ + # Config + "ConnectorSpec", + "OmniTransferConfig", + # Connectors + "OmniConnectorBase", + "OmniConnectorFactory", + "MooncakeConnector", + "SharedMemoryConnector", + "YuanrongConnector", + # Utilities + "load_omni_transfer_config", +] diff --git a/vllm_omni/distributed/omni_connectors/__init__.py b/vllm_omni/distributed/omni_connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fcec32eaae381117b167f27247b9dd633c7a0b0b --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/__init__.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .connectors.base import OmniConnectorBase +from .connectors.mooncake_connector import MooncakeConnector +from .connectors.shm_connector import SharedMemoryConnector +from .connectors.yuanrong_connector import YuanrongConnector +from .factory import OmniConnectorFactory +from .utils.config import ConnectorSpec, OmniTransferConfig +from .utils.initialization import ( + build_stage_connectors, + get_connectors_config_for_stage, + get_stage_connector_config, + initialize_connectors_from_config, + initialize_orchestrator_connectors, + load_omni_transfer_config, +) + +__all__ = [ + # Config + "ConnectorSpec", + "OmniTransferConfig", + # Base classes and implementations + "OmniConnectorBase", + # Factory + "OmniConnectorFactory", + # Specific implementations + "MooncakeConnector", + "SharedMemoryConnector", + "YuanrongConnector", + # Utilities + "load_omni_transfer_config", + "initialize_connectors_from_config", + "get_connectors_config_for_stage", + # Manager helpers + "initialize_orchestrator_connectors", + "get_stage_connector_config", + 
"build_stage_connectors", +] diff --git a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..977ecc88b2b82315f45b2940e0e04722ba3ee946 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/adapter.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# temporary for compatibility with vllm_omni.entrypoints.omni_stage.py +# and vllm_omni.entrypoints.omni_llm.py + +import time +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import torch +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +if TYPE_CHECKING: + from .connectors.base import OmniConnectorBase + +from vllm_omni.entrypoints.stage_utils import OmniStageTaskType + +from .utils.logging import get_connector_logger + +logger = get_connector_logger(__name__) + + +def try_send_via_connector( + connector: Any, + stage_id: int, + next_stage_id: int, + req_id: str, + next_inputs: Any, + sampling_params: Any, + original_prompt: Any, + next_stage_queue_submit_fn: Callable[[dict[str, Any]], None], + metrics: Any, +) -> bool: + """ + Attempts to send data via OmniConnector. + Returns True if successful, False otherwise. + Encapsulates the logic of preparing payload, sending via connector, + sending notification, and recording metrics. + """ + try: + t0 = time.time() + + # Prepare data for connector + payload_data = { + "engine_inputs": next_inputs, + "sampling_params": sampling_params, + "metadata": { + "original_prompt": original_prompt, + "stage_transition": f"{stage_id}->{next_stage_id}", + "timestamp": time.time(), + }, + } + + # Send data via connector + success, serialized_size, metadata = connector.put(str(stage_id), str(next_stage_id), str(req_id), payload_data) + + if success: + # Send lightweight notification via queue + notify_payload = { + "type": OmniStageTaskType.GENERATE, + "request_id": req_id, + "sampling_params": sampling_params, + "from_connector": True, + "from_stage": str(stage_id), + "to_stage": str(next_stage_id), + "sent_ts": time.time(), + } + # Merge connector metadata (e.g. shm handle or inline data) into queue payload + if metadata: + notify_payload["connector_metadata"] = metadata + + next_stage_queue_submit_fn(notify_payload) + + t1 = time.time() + tx_ms = (t1 - t0) * 1000.0 + + metrics.on_forward( + stage_id, + next_stage_id, + req_id, + serialized_size, # Use size from connector + float(tx_ms), + True, # Mark as using connector + ) + return True + else: + # If put returned False, we let the caller handle fallback + return False + + except Exception as e: + logger.warning( + "[Orchestrator] OmniConnector failed for req %s: %s; falling back to queue", + req_id, + e, + ) + return False + + +def try_recv_via_connector( + task: dict[str, Any], + connectors: dict[Any, Any], + stage_id: int, +) -> tuple[Any, dict[str, Any] | None]: + """ + Attempts to resolve input data from either connector or IPC. + Returns (engine_inputs, rx_metrics) or (None, None) if failed/skipped. 
+ """ + rid = task["request_id"] + + if task.get("from_connector"): + from_stage = task.get("from_stage") + to_stage = str(stage_id) + + if not from_stage: + logger.error( + "[Stage-%s] 'from_connector' is true but 'from_stage' is missing for request %s", stage_id, rid + ) + return None, None + + # Get connector for this edge + connector_key = (from_stage, to_stage) + connector = connectors.get(connector_key) + + if connector: + try: + # Get data from connector with timeout + _t_start = time.time() + connector_metadata = task.get("connector_metadata") + payload = connector.get(from_stage, to_stage, str(rid), metadata=connector_metadata) + _t_end = time.time() + + if payload: + if isinstance(payload, tuple): + payload_data, serialized_size = payload + else: + payload_data = payload + serialized_size = len(connector.serialize_obj(payload_data)) + else: + payload_data = None + serialized_size = 0 + + if payload_data and isinstance(payload_data, dict): + ein = payload_data.get("engine_inputs") + decode_ms = (_t_end - _t_start) * 1000.0 + + rx_metrics = {"rx_decode_time_ms": decode_ms, "rx_transfer_bytes": serialized_size} + return ein, rx_metrics + else: + logger.error( + "[Stage-%s] Failed to get data from connector for request %s or payload is empty", stage_id, rid + ) + return None, None + except Exception as e: + logger.error("[Stage-%s] Error retrieving data from connector for request %s: %s", stage_id, rid, e) + return None, None + else: + logger.error( + "[Stage-%s] No connector found for edge %s -> %s for request %s", stage_id, from_stage, to_stage, rid + ) + return None, None + else: + # Data comes from queue as usual (e.g. seed request for Stage-0) + # Since fallback logic is deprecated, we assume this is a direct inputs payload. + # We still need to decode it if it used SHM (via legacy stage_utils logic, or new shm_connector format) + # For Stage-0 specifically, 'engine_inputs' is often directly in the task dict. + + # Try to use the new stage_utils which uses OmniSerializer + from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc_with_metrics + + try: + ein, metrics = maybe_load_from_ipc_with_metrics(task, "engine_inputs", "engine_inputs_shm") + # If metrics are empty or zero, we might want to populate dummy metrics + return ein, metrics + except Exception: + # If engine_inputs is missing, it might be a different kind of payload, + # but for Stage-0 seed it should be there. + # We'll return None to let caller handle error if strictly required. + return None, None + + +def get_chunk( + connector: "OmniConnectorBase", + scheduler_output: SchedulerOutput, +) -> None: + """Retrieve a chunk of pooling output. 
+ + Args: + connector: OmniConnectorBase instance + scheduler_output: Partial scheduler output dictionary + + Returns: + None: This function modifies scheduler_output in place + """ + stage_id = connector.stage_id + if stage_id == 0: + return + + target_stage_id = stage_id - 1 + # Handle new requests + for new_req_data in scheduler_output.scheduled_new_reqs: + connector.request_ids_mapping[new_req_data.req_id] = new_req_data.external_req_id + req_id = new_req_data.external_req_id + chunk_id = connector.get_requests[req_id] + connector_get_key = f"{req_id}_{target_stage_id}_{chunk_id}" + payload_data = get_through_connector(connector, target_stage_id, stage_id, req_id, connector_get_key) + if payload_data: + new_req_data.additional_information = payload_data + if payload_data.get("finished"): + connector.finished_requests.add(req_id) + + # Handle cached/running requests + cached_reqs = scheduler_output.scheduled_cached_reqs + if not hasattr(cached_reqs, "additional_information"): + cached_reqs.additional_information = {} + + for i, cached_req_id in enumerate(cached_reqs.req_ids): + req_id = connector.request_ids_mapping.get(cached_req_id, cached_req_id) + if req_id in connector.finished_requests: + continue + chunk_id = connector.get_requests[req_id] + connector_get_key = f"{req_id}_{target_stage_id}_{chunk_id}" + payload_data = get_through_connector(connector, target_stage_id, stage_id, req_id, connector_get_key) + if payload_data: + cached_reqs.additional_information[cached_req_id] = payload_data + if payload_data.get("finished"): + connector.finished_requests.add(req_id) + + +def get_through_connector(connector, target_stage_id, stage_id, req_id, connector_get_key): + # Wait for data from previous stage + import time + + # TODO: add correct check mechanism for the payload_data + max_wait = 300 + for _ in range(max_wait): + result = connector.get( + from_stage=str(target_stage_id), + to_stage=str(stage_id), + get_key=connector_get_key, + ) + payload_data = None + if result: + payload_data, size = result + if payload_data: + connector.request_prompt_token_ids[req_id] = payload_data.get("thinker_input_ids", []) + connector.get_requests[req_id] += 1 + logger.debug("[Stage-%d] Received one chunk for request %s", stage_id, connector_get_key) + break + time.sleep(0.01) + return payload_data + + +def get_chunk_for_generation( + connector: "OmniConnectorBase", + request: Request, +) -> None: + """Retrieve a chunk of pooling output. + + Args: + connector: OmniConnectorBase instance + request: Request object + + Returns: + None: This function modifies request in place + """ + stage_id = connector.stage_id + target_stage_id = stage_id - 1 + request_id = request.external_req_id + + if request_id in connector.finished_requests: + return + + chunk_id = connector.get_requests[request_id] + connector_get_key = f"{request_id}_{target_stage_id}_{chunk_id}" + payload_data = get_through_connector(connector, target_stage_id, stage_id, request_id, connector_get_key) + if not payload_data: + return + + if payload_data.get("finished"): + connector.finished_requests.add(request_id) + request.status = RequestStatus.FINISHED_STOPPED + + # TODO: remove special handling for prompt token ids ? 
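+    # The first chunk seeds request.prompt_token_ids from the predictor codes;
+    # later chunks are appended to it.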
+ if chunk_id == 0: + request.prompt_token_ids = payload_data.get("code_predictor_codes", []) + else: + request.prompt_token_ids += payload_data.get("code_predictor_codes", []) + + +def put_chunk( + connector: "OmniConnectorBase", + pooling_output: dict[str, Any], + request: Request, + custom_process_input_func: Callable[[dict[str, Any], Request], dict[str, Any] | None] | None = None, +) -> None: + """Store a chunk of pooling output. + + Args: + connector: OmniConnectorBase instance + pooling_output: Partial pooling output dictionary + request: Request object + custom_process_input_func: Optional custom function to process input + + Returns: + None: This function sends data via connector + """ + stage_id = connector.stage_id + next_stage_id = stage_id + 1 + request_id = request.external_req_id + prompt_token_ids = request.prompt_token_ids + connector.request_prompt_token_ids[request_id] = prompt_token_ids + chunk_id = connector.put_requests[request_id] + connector_put_key = f"{request_id}_{stage_id}_{chunk_id}" + payload_data = None + + # TODO: add default process_input_func to handle the payload_data ? + if custom_process_input_func: + try: + payload_data = custom_process_input_func( + pooling_output=pooling_output, + request=request, + ) + except Exception as e: + logger.error(f"Failed to use custom_process_input_func for payload extraction: {e}") + + if not payload_data: + logger.warning("[Stage-%d] No payload data to send for request %s", stage_id, request_id) + return + + if stage_id == 0 and chunk_id == 0: + if connector.request_payload.get(request_id) is None: + if not payload_data.get("finished"): + connector.request_payload[request_id] = payload_data + return + else: + save_payload = connector.request_payload.pop(request_id) + payload_data["thinker_embeddings"] = torch.cat( + (save_payload.get("thinker_embeddings"), payload_data.get("thinker_embeddings")), dim=0 + ) + payload_data["thinker_hidden_states"] = torch.cat( + (save_payload.get("thinker_hidden_states"), payload_data.get("thinker_hidden_states")), dim=0 + ) + logger.debug("[Stage-%d] Merged embeddings and hidden states for request %s", stage_id, request_id) + + if stage_id == 1: + # TODO: Make parameters configurable and optimize algorithms + chunk_size = left_context_size = 25 + connector.code_prompt_token_ids[request_id].append(payload_data.get("code_predictor_codes", [])) + length = len(connector.code_prompt_token_ids[request_id]) + chunk_length = length % chunk_size + if chunk_length != 0 and not payload_data.get("finished"): + return + + context_length = chunk_length if chunk_length != 0 else chunk_size + end_index = min(length, left_context_size + context_length) + payload_data["code_predictor_codes"] = ( + torch.tensor(connector.code_prompt_token_ids[request_id][-end_index:]) + .transpose(0, 1) + .reshape(-1) + .tolist() + ) + + success, size, metadata = connector.put( + from_stage=str(stage_id), to_stage=str(next_stage_id), put_key=connector_put_key, data=payload_data + ) + + if success: + connector.put_requests[request_id] += 1 + logger.debug("[Stage-%d] Sent %s", stage_id, connector_put_key) + + +def compute_talker_prompt_ids_length(prompt_ids: list[int]) -> int: + """Compute the length of the talker prompt ids. + + Args: + prompt_ids: The prompt ids tensor. + + Returns: + The length of the talker prompt ids. 
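+        Computed as the sum of all user-turn span lengths (turns are delimited by
+        the hard-coded im_start/role token ids below) plus a fixed 9-token
+        allowance for the final assistant turn; system turns are excluded.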
+ """ + im_start_token_id = 151644 + system_token_id = 8948 + user_token_id = 872 + assistant_token_id = 77091 + im_start_indexes = [i for i in range(len(prompt_ids)) if prompt_ids[i] == im_start_token_id] + im_start_indexes.append(len(prompt_ids)) + sum_user_len = 0 + assistant_len = 0 + for i in range(len(im_start_indexes) - 1): + s = im_start_indexes[i] + e = im_start_indexes[i + 1] + role = prompt_ids[s + 1] + if role == system_token_id: + continue + elif role == user_token_id: + sum_user_len += e - s + elif role == assistant_token_id and i == len(im_start_indexes) - 2: + assistant_len += 9 # 3 + 4 + 1 + 1 + else: + pass + + return sum_user_len + assistant_len diff --git a/vllm_omni/distributed/omni_connectors/connectors/__init__.py b/vllm_omni/distributed/omni_connectors/connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..208f01a7cb5ee04c88d276fec2082cd4e830884b --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e3562b89681c2d446c3557655eacf628c0d5e4db --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/base.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Any + +from ..utils.logging import get_connector_logger + +logger = get_connector_logger(__name__) + + +class OmniConnectorBase(ABC): + """Base class for all OmniConnectors.""" + + @abstractmethod + def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, dict[str, Any] | None]: + """Store Python object, internal serialization handled by connector. + + Args: + from_stage: Source stage identifier + to_stage: Destination stage identifier + put_key: Unique request identifier + data: Python object to store + + Returns: + tuple: (success: bool, serialized_size: int, metadata: Optional[dict]) + Metadata may contain transport-specific handles or inline data. + """ + pass + + @abstractmethod + def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None: + """Retrieve Python object and payload size (bytes). 
+ + Args: + from_stage: Source stage identifier + to_stage: Destination stage identifier + get_key: Unique request identifier + + Returns: + Tuple of (Python object, serialized byte size) if found, None otherwise + """ + pass + + @abstractmethod + def cleanup(self, request_id: str) -> None: + """Clean up resources for a request.""" + pass + + @abstractmethod + def health(self) -> dict[str, Any]: + """Return health status and metrics.""" + pass + + @staticmethod + def serialize_obj(obj: Any) -> bytes: + """Serialize a Python object to bytes using centralized serializer.""" + from ..utils.serialization import OmniSerializer + + return OmniSerializer.serialize(obj) + + @staticmethod + def deserialize_obj(data: bytes) -> Any: + """Deserialize bytes to Python object using centralized serializer.""" + from ..utils.serialization import OmniSerializer + + return OmniSerializer.deserialize(data) diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..e28486481c4b1e933cf0e513d9d0d878a4f325b7 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import Any + +from ..utils.logging import get_connector_logger +from .base import OmniConnectorBase + +logger = get_connector_logger(__name__) + +try: + from mooncake.store import MooncakeDistributedStore, ReplicateConfig +except ImportError: + try: + from mooncake import MooncakeDistributedStore, ReplicateConfig + except ImportError: + MooncakeDistributedStore = None + ReplicateConfig = None + + +class MooncakeConnector(OmniConnectorBase): + """Mooncake-based distributed connector for OmniConnector.""" + + def __init__(self, config: dict[str, Any]): + if MooncakeDistributedStore is None or ReplicateConfig is None: + raise ImportError( + "Mooncake components (MooncakeDistributedStore/ReplicateConfig) are not available. " + "Please ensure the 'mooncake' package is installed in your environment." 
+ ) + + self.config = config + self.host = config.get("host", "127.0.0.1") + self.metadata = config.get("metadata_server", "http://127.0.0.1:8080/metadata") + self.master = config.get("master", "127.0.0.1:50051") + self.segment = config.get("segment", 512 * 1024 * 1024) # 512MB + self.localbuf = config.get("localbuf", 64 * 1024 * 1024) # 64MB + self.proto = config.get("proto", "tcp") + self.rdma = config.get("rdma", "") + + self.store: MooncakeDistributedStore | None = None + self.pin: ReplicateConfig | None = None + + self._metrics = { + "puts": 0, + "gets": 0, + "bytes_transferred": 0, + "errors": 0, + "timeouts": 0, + } + + self._init_store() + + def _make_key(self, rid: str, from_stage: str, to_stage: str) -> str: + """Generate store key for request between stages.""" + return f"{rid}/{from_stage}_{to_stage}" + + def _init_store(self): + """Initialize Mooncake store.""" + try: + self.store = MooncakeDistributedStore() + rc = self.store.setup( + self.host, self.metadata, self.segment, self.localbuf, self.proto, self.rdma, self.master + ) + if rc != 0: + raise RuntimeError(f"Mooncake setup failed: {rc}") + + self.pin = ReplicateConfig() + self.pin.with_soft_pin = True + logger.info("MooncakeConnector initialized successfully") + except Exception as e: + logger.error("Failed to initialize Mooncake store: %s", e) + raise + + # Use base class serialization methods for consistency + + def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, dict[str, Any] | None]: + if not self.store: + logger.error("Store not initialized") + return False, 0, None + + try: + serialized_data = self.serialize_obj(data) + key = self._make_key(put_key, from_stage, to_stage) + self.store.put(key, serialized_data, self.pin) + + self._metrics["puts"] += 1 + self._metrics["bytes_transferred"] += len(serialized_data) + + logger.debug( + "MooncakeConnector: stored %s (%s -> %s) %d bytes", + key, + from_stage, + to_stage, + len(serialized_data), + ) + return True, len(serialized_data), None + + except Exception as e: + self._metrics["errors"] += 1 + logger.error("MooncakeConnector put failed: %s", e) + return False, 0, None + + def get( + self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: + if not self.store: + logger.error("Store not initialized") + return None + + retries = 20 + sleep_s = 0.05 + key = self._make_key(get_key, from_stage, to_stage) + + for attempt in range(retries): + try: + raw_data = self.store.get(key) + + if raw_data: + data = self.deserialize_obj(raw_data) + self._metrics["gets"] += 1 + payload_size = len(raw_data) + logger.debug( + "MooncakeConnector: retrieved %s (%s -> %s) %d bytes", + key, + from_stage, + to_stage, + payload_size, + ) + return data, payload_size + + except Exception as e: + logger.debug("MooncakeConnector get attempt %s failed: %s", attempt, e) + + if attempt < retries - 1: + time.sleep(sleep_s) + + self._metrics["timeouts"] += 1 + logger.warning("MooncakeConnector: timeout waiting for %s", key) + return None + + def cleanup(self, request_id: str) -> None: + if not self.store: + return + + # Note: Mooncake doesn't have explicit delete, data will be garbage collected + # We could implement a cleanup mechanism by storing deletion markers + logger.debug("MooncakeConnector: cleanup requested for %s (no-op)", request_id) + + def health(self) -> dict[str, Any]: + if not self.store: + return {"status": "unhealthy", "error": "Store not initialized"} + + return { + "status": 
"healthy", + "host": self.host, + "metadata_server": self.metadata, + "master": self.master, + **self._metrics, + } + + def close(self): + """Clean shutdown.""" + if self.store: + try: + self.store.close() + self.store = None + logger.info("MooncakeConnector closed") + except Exception as e: + logger.error("Error closing Mooncake store: %s", e) diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad32e1c7161646052704be218d5d9e49336b804 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import fcntl +import os +import time +from collections import defaultdict +from typing import Any + +from vllm_omni.entrypoints.stage_utils import shm_read_bytes, shm_write_bytes + +from ..utils.logging import get_connector_logger +from .base import OmniConnectorBase + +logger = get_connector_logger(__name__) + + +class SharedMemoryConnector(OmniConnectorBase): + """ + Connector that uses SharedMemory for large objects and inline data for small objects. + Acts as a unified replacement for the legacy IPC fallback logic. + """ + + def __init__(self, config: dict[str, Any]): + self.config = config + self.stage_id = config.get("stage_id", -1) + self.device = config.get("device", "cuda:0") + self.put_requests: dict[str, int] = defaultdict(int) + self.get_requests: dict[str, int] = defaultdict(int) + self.finished_requests: set[str] = set() + self.request_payload = {} + self.request_prompt_token_ids: dict[str, list[int]] = defaultdict(list) + self.code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list) + self.request_ids_mapping: dict[str, str] = {} + # Default threshold matches legacy behavior (64KB) + self.threshold = int(config.get("shm_threshold_bytes", 65536)) + self._metrics = { + "puts": 0, + "gets": 0, + "bytes_transferred": 0, + "shm_writes": 0, + "inline_writes": 0, + } + + def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, dict[str, Any] | None]: + try: + # Always serialize first to check size (and for SHM writing) + # Note: For extremely large objects in "inline" mode (e.g. Ray), + # we might double-serialize if we're not careful, but here we assume + # if it's huge we use SHM, or if Ray, threshold is maxsize. + payload = self.serialize_obj(data) + size = len(payload) + + metadata = {} + # if size > self.threshold: + if True: # TODO: correct put & get logic + # Use Shared Memory + lock_file = f"/dev/shm/shm_{put_key}_lockfile.lock" + with open(lock_file, "w") as lockf: + fcntl.flock(lockf, fcntl.LOCK_EX) + meta = shm_write_bytes(payload, name=put_key) + fcntl.flock(lockf, fcntl.LOCK_UN) + + # meta contains {'name': ..., 'size': ...} + metadata[put_key] = {"shm": meta, "size": size} + self._metrics["shm_writes"] += 1 + else: + # Inline - pass bytes directly to avoid double serialization of the object + # We already serialized it to check size, so we pass the bytes. + # The Queue will pickle these bytes (fast), avoiding re-serializing the complex object. 
+ metadata[put_key] = {"inline_bytes": payload, "size": size} + self._metrics["inline_writes"] += 1 + + self._metrics["puts"] += 1 + self._metrics["bytes_transferred"] += size + + return True, size, metadata + + except Exception as e: + logger.error(f"SharedMemoryConnector put failed for req {put_key}: {e}") + return False, 0, None + + def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None: + from multiprocessing import shared_memory as shm_pkg + + # Wait for shared memory to be available (with retry logic) + max_retries = 30 + retry_delay = 0.1 # 100ms between retries + shm = None + + for attempt in range(max_retries): + try: + shm = shm_pkg.SharedMemory(name=get_key) + break # Successfully opened, exit retry loop + except FileNotFoundError: + if attempt < max_retries - 1: + time.sleep(retry_delay) + else: + # Max retries reached, return None + logger.warning(f"Shared memory '{get_key}' not found after {max_retries} retries") + return None, 0 + + if shm is None: + return None, 0 + + try: + lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" + with open(lock_file) as lockf: + fcntl.flock(lockf, fcntl.LOCK_SH) + data_bytes = shm_read_bytes({"name": get_key, "size": shm.size}) + fcntl.flock(lockf, fcntl.LOCK_UN) + # Clean up the temporary file if it still exists. + if os.path.exists(lock_file): + os.remove(lock_file) + obj = self.deserialize_obj(data_bytes) + return obj, shm.size + finally: + shm.close() + + # TODO: update another read method + + def cleanup(self, request_id: str) -> None: + # SHM segments are automatically unlinked during 'get' (shm_read_bytes). + # If 'get' is never called (e.g. error flow), the SHM segment might leak. + # A robust implementation might track created segments and unlink them here + # if they haven't been consumed. + # For now, we rely on the consumer to read and unlink. + pass + + def health(self) -> dict[str, Any]: + return {"status": "healthy", "threshold": self.threshold, **self._metrics} diff --git a/vllm_omni/distributed/omni_connectors/connectors/yuanrong_connector.py b/vllm_omni/distributed/omni_connectors/connectors/yuanrong_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e4c85d5cd719cb50bade1bd39f8e527760aabd --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/connectors/yuanrong_connector.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from ..utils.logging import get_connector_logger +from .base import OmniConnectorBase + +logger = get_connector_logger(__name__) + +try: + from datasystem.kv_client import KVClient, SetParam, WriteMode +except ImportError: + KVClient = None + SetParam = None + WriteMode = None + + +class YuanrongConnector(OmniConnectorBase): + """Datasystem-based distributed connector for OmniConnector.""" + + def __init__(self, config: dict[str, Any]): + if KVClient is None or SetParam is None or WriteMode is None: + raise ImportError( + "Datasystem components (KVClient/SetParam/WriteMode) are not available. " + "Please ensure the 'datasystem' package is installed in your environment." 
+ ) + + self.config = config + self.client = None + self.set_param = SetParam() + self.set_param.write_mode = WriteMode.NONE_L2_CACHE_EVICT + self.get_sub_timeout_ms = max(0, int(self.config.get("get_sub_timeout_ms", 1000))) + + self._metrics = { + "puts": 0, + "gets": 0, + "bytes_transferred": 0, + "errors": 0, + "timeouts": 0, + } + + self._init_client() + + def _make_key(self, rid: str, from_stage: str, to_stage: str) -> str: + """Generate key for request between stages.""" + return f"{rid}:{from_stage}_{to_stage}" + + def _init_client(self): + """Initialize Datasystem client.""" + try: + self.host = self.config.get("host", "127.0.0.1") + self.port = int(self.config.get("port", "35001")) + self.client = KVClient(self.host, self.port) + self.client.init() + + logger.info("YuanrongConnector initialized successfully") + except Exception as e: + logger.error("Failed to initialize Datasystem client: %s", e) + raise + + def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[bool, int, dict[str, Any] | None]: + if not self.client: + logger.error("Datasystem client not initialized") + return False, 0, None + + try: + serialized_data = self.serialize_obj(data) + key = self._make_key(put_key, from_stage, to_stage) + self.client.set(key, serialized_data, self.set_param.write_mode) + + self._metrics["puts"] += 1 + self._metrics["bytes_transferred"] += len(serialized_data) + + logger.debug( + "YuanrongConnector: stored %s (%s -> %s) %d bytes", + key, + from_stage, + to_stage, + len(serialized_data), + ) + return True, len(serialized_data), None + + except Exception as exc: + self._metrics["errors"] += 1 + logger.error("YuanrongConnector put failed: %s", exc) + return False, 0, None + + def get( + self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: + if not self.client: + logger.error("Datasystem client not initialized") + return None + + key = self._make_key(get_key, from_stage, to_stage) + try: + raw_list = self.client.get([key], False, self.get_sub_timeout_ms) + raw_data = raw_list[0] if raw_list else None + if raw_data is not None: + data = self.deserialize_obj(raw_data) + self._metrics["gets"] += 1 + payload_size = len(raw_data) + logger.debug( + "YuanrongConnector: retrieved %s (%s -> %s) %d bytes", + key, + from_stage, + to_stage, + payload_size, + ) + return data, payload_size + + except Exception as exc: + self._metrics["timeouts"] += 1 + logger.error("YuanrongConnector get failed: %s", exc) + return None + + def cleanup(self, request_id: str) -> None: + if not self.client: + return + + # Note: Datasystem doesn't have explicit delete, data will be garbage collected + logger.debug("YuanrongConnector: cleanup requested for %s (no-op)", request_id) + + def health(self) -> dict[str, Any]: + if not self.client: + return {"status": "unhealthy", "error": "Datasystem client not initialized"} + + return {"status": "healthy", "host": self.host, "port": self.port, **self._metrics} + + def close(self) -> None: + if not self.client: + return + + self.client = None + logger.info("YuanrongConnector closed") diff --git a/vllm_omni/distributed/omni_connectors/factory.py b/vllm_omni/distributed/omni_connectors/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..34415f7fddeb6f5d26f992e53280d210d96e59e8 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/factory.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project + +from collections.abc import Callable +from typing import Any + +from .utils.logging import get_connector_logger + +try: + from .connectors.base import OmniConnectorBase + from .utils.config import ConnectorSpec +except ImportError: + # Fallback for direct execution + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from omni_connectors.connectors.base import OmniConnectorBase + from omni_connectors.utils.config import ConnectorSpec + +logger = get_connector_logger(__name__) + + +class OmniConnectorFactory: + """Factory for creating OmniConnectors.""" + + _registry: dict[str, Callable[[dict[str, Any]], OmniConnectorBase]] = {} + + @classmethod + def register_connector(cls, name: str, constructor: Callable[[dict[str, Any]], OmniConnectorBase]) -> None: + """Register a connector constructor.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + cls._registry[name] = constructor + logger.debug(f"Registered connector: {name}") + + @classmethod + def create_connector(cls, spec: ConnectorSpec) -> OmniConnectorBase: + """Create a connector from specification.""" + if spec.name not in cls._registry: + raise ValueError(f"Unknown connector: {spec.name}. Available: {list(cls._registry.keys())}") + + constructor = cls._registry[spec.name] + try: + connector = constructor(spec.extra) + logger.info(f"Created connector: {spec.name}") + return connector + except Exception as e: + logger.error(f"Failed to create connector {spec.name}: {e}") + raise ValueError(f"Failed to create connector {spec.name}: {e}") + + @classmethod + def list_registered_connectors(cls) -> list[str]: + """List all registered connector names.""" + return list(cls._registry.keys()) + + +# Register built-in connectors with lazy imports +def _create_mooncake_connector(config: dict[str, Any]) -> OmniConnectorBase: + try: + from .connectors.mooncake_connector import MooncakeConnector + except ImportError: + # Fallback import + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from omni_connectors.connectors.mooncake_connector import MooncakeConnector + return MooncakeConnector(config) + + +def _create_shm_connector(config: dict[str, Any]) -> OmniConnectorBase: + try: + from .connectors.shm_connector import SharedMemoryConnector + except ImportError: + # Fallback import + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from omni_connectors.connectors.shm_connector import SharedMemoryConnector + return SharedMemoryConnector(config) + + +def _create_yuanrong_connector(config: dict[str, Any]) -> OmniConnectorBase: + try: + from .connectors.yuanrong_connector import YuanrongConnector + except ImportError: + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + from omni_connectors.connectors.yuanrong_connector import YuanrongConnector + return YuanrongConnector(config) + + +# Register connectors +OmniConnectorFactory.register_connector("MooncakeConnector", _create_mooncake_connector) +OmniConnectorFactory.register_connector("SharedMemoryConnector", _create_shm_connector) +OmniConnectorFactory.register_connector("YuanrongConnector", _create_yuanrong_connector) diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..82c06fdafeec84281e7f2a4e6a25b5185a36acf2 --- /dev/null +++ 
b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unified OmniConnector and KV cache transfer management.""" + +import time +from collections.abc import Callable +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from vllm.logger import init_logger + +from .factory import OmniConnectorFactory +from .utils.config import ConnectorSpec + +logger = init_logger(__name__) + + +@dataclass +class OmniKVCacheConfig: + """Configuration for OmniKVTransferManager.""" + + connector_config: dict[str, Any] | None = None + from_stage: str | None = None + to_stage: str | None = None + stage_id: str | int | None = None + engine_input_source: list[str | int] | None = None + need_recv_cache: bool = False + need_send_cache: bool = False + recv_timeout: float = 30.0 + + +@dataclass +class KVCacheTransferData: + """Container for KV cache transfer data.""" + + request_id: str + layer_blocks: dict[str, Any] + block_ids: list[int] + metadata: dict[str, Any] + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + +class OmniKVTransferManager: + """Unified management for OmniConnector and KV cache transfer. + + This class encapsulates all KV cache related operations: + - Connector initialization and lazy creation + - KV cache extraction from GPU blocks + - KV cache transfer with retry logic + - KV cache receiving with timeout + """ + + def __init__(self, config: OmniKVCacheConfig): + self.config = config + self._connector = None + + # Pre-calculate send stages (from_stage, to_stage) + self.send_stages = ( + (str(config.from_stage), str(config.to_stage)) if config.from_stage and config.to_stage else (None, None) + ) + + # Pre-calculate receive stages (from_stage, to_stage) + recv_from = config.from_stage + if config.engine_input_source: + recv_from = config.engine_input_source[0] + elif isinstance(config.stage_id, int): + recv_from = config.stage_id - 1 + + self.recv_stages = ( + (str(recv_from), str(config.stage_id)) + if recv_from is not None and config.stage_id is not None + else (None, None) + ) + + @classmethod + def _create(cls, cfg: dict | None) -> "OmniKVTransferManager": + """Create manager from raw config dict.""" + if not cfg or not isinstance(cfg, dict): + return cls(OmniKVCacheConfig()) + return cls( + OmniKVCacheConfig( + connector_config=cfg.get("connector_config"), + from_stage=cfg.get("omni_from_stage"), + to_stage=cfg.get("omni_to_stage"), + stage_id=cfg.get("stage_id"), + engine_input_source=cfg.get("engine_input_source", []), + need_recv_cache=cfg.get("need_recv_cache", False), + need_send_cache=cfg.get("need_send_cache", False), + recv_timeout=cfg.get("recv_timeout", 30.0), + ) + ) + + @classmethod + def from_model_config(cls, config: Any) -> "OmniKVTransferManager": + """Create from model config (for AR model runner).""" + return cls._create(getattr(config, "omni_kv_config", None)) + + @classmethod + def from_od_config(cls, config: Any) -> "OmniKVTransferManager": + """Create from OmniDiffusion config (for diffusion runner).""" + return cls._create(getattr(config, "omni_kv_config", None)) + + @classmethod + def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager": + """Create from vllm config with fallback to kv_transfer_config.""" + # Primary: omni_kv_config from model_config + omni_kv = getattr(model_config, "omni_kv_config", 
None) + if isinstance(omni_kv, dict): + return cls._create(omni_kv) + + # Fallback: check kv_transfer_config + kv_cfg = getattr(vllm_config, "kv_transfer_config", None) + if kv_cfg: + direct = getattr(kv_cfg, "omni_connector_config", None) + if isinstance(direct, dict) and direct: + return cls._create({"connector_config": direct}) + extra = getattr(kv_cfg, "kv_connector_extra_config", None) + if isinstance(extra, dict): + omni = extra.get("omni_connector_config") + if isinstance(omni, dict) and omni: + return cls._create({"connector_config": omni}) + + return cls(OmniKVCacheConfig()) + + @property + def connector(self): + """Lazy initialization of connector.""" + # If a previous initialization attempt failed, don't retry on every access. + if self._connector is False: + return None + + if self._connector is None: + cfg = self.config.connector_config + if cfg and (c_type := cfg.get("type")): + try: + logger.info(f"Initializing OmniConnector with config: {cfg}") + c_extra = {k: v for k, v in cfg.items() if k != "type"} + self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) + except Exception as e: + logger.error(f"Failed to initialize OmniConnector: {e}") + import traceback + + traceback.print_exc() + # Cache failure sentinel to avoid repeated initialization attempts in hot paths. + self._connector = False + + return self._connector if self._connector else None + + def get_connector(self): + """Get connector (compatibility wrapper for existing code).""" + return self.connector + + def handle_finished_requests_kv_transfer( + self, + finished_reqs: dict[str, dict[str, Any]], + kv_caches: list[torch.Tensor], + block_size: int, + cache_dtype: str, + request_id_resolver: Callable[[str], str] | None = None, + ) -> list[str]: + """Handle KV cache transfer for finished requests. + + This method extracts KV cache from GPU blocks and transfers them + to the downstream stage via the connector. 
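+
+        Request IDs are still returned (so their blocks can be freed) when
+        need_send_cache is False or no connector is available; only the transfer
+        itself is skipped in those cases.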
+ + Args: + finished_reqs: Dict mapping request_id to {block_ids, seq_len} + kv_caches: List of KV cache tensors per layer + block_size: Size of each cache block + cache_dtype: Data type of the cache + request_id_resolver: Optional function to resolve global request ID + + Returns: + List of request IDs that were processed + """ + if not finished_reqs: + return [] + + if not self.config.need_send_cache: + return list(finished_reqs.keys()) + + if not self.connector: + logger.warning("No connector available, skipping KV transfer but freeing resources") + return list(finished_reqs.keys()) + + logger.debug(f"Processing KV transfer for {len(finished_reqs)} requests") + + extracted_ids = [] + for req_id, data in finished_reqs.items(): + try: + seq_len = data.get("seq_len", 0) + block_ids = data.get("block_ids", []) + if not block_ids: + logger.warning(f"Request {req_id} has no block IDs, skipping") + continue + + # Extract KV cache from GPU blocks -> CPU tensors + kv_data = self._extract_kv_cache(req_id, block_ids, seq_len, kv_caches, block_size, cache_dtype) + if kv_data: + # Resolve global request ID if available + transfer_req_id = request_id_resolver(req_id) if request_id_resolver else req_id + + # Transfer to downstream stage via connector + self._transfer_kv_cache(kv_data, transfer_req_id) + + except Exception as e: + logger.error(f"Failed KV transfer for {req_id}: {e}") + finally: + extracted_ids.append(req_id) + + return extracted_ids + + def _extract_kv_cache( + self, + req_id: str, + block_ids: list[int], + seq_len: int, + kv_caches: list[torch.Tensor], + block_size: int, + cache_dtype: str, + ) -> KVCacheTransferData | None: + """Extract KV cache from GPU blocks for a single request. + + Args: + req_id: Request identifier + block_ids: List of block IDs to extract + seq_len: Sequence length + kv_caches: List of KV cache tensors per layer + block_size: Size of each cache block + cache_dtype: Data type of the cache + + Returns: + KVCacheTransferData if extraction successful, None otherwise + """ + num_layers = len(kv_caches) + key_cache: list[torch.Tensor | None] = [None] * num_layers + value_cache: list[torch.Tensor | None] = [None] * num_layers + + for layer_idx, kv_tensor in enumerate(kv_caches): + # Validate block IDs - shape: [2, num_blocks, block_size, n_heads, head_dim] + max_block = kv_tensor.shape[1] - 1 + valid_ids = [bid for bid in block_ids if 0 <= bid <= max_block] + if not valid_ids: + continue + + # Extract and reshape: [2, n_blocks, block_size, n_heads, head_dim] + # -> [2, seq_len, n_heads, head_dim] + selected = kv_tensor[:, valid_ids] # [2, n_valid, block_size, n_heads, head_dim] + n_kv, n_blks, blk_sz, n_heads, d_head = selected.shape + flat = selected.reshape(n_kv, n_blks * blk_sz, n_heads, d_head) + if seq_len < flat.shape[1]: + flat = flat[:, :seq_len] + + # Move to CPU + flat_cpu = flat.detach().cpu().contiguous() + key_cache[layer_idx] = flat_cpu[0] + value_cache[layer_idx] = flat_cpu[1] + + if not any(k is not None for k in key_cache): + return None + + return KVCacheTransferData( + request_id=req_id, + layer_blocks={"key_cache": key_cache, "value_cache": value_cache}, + block_ids=block_ids, + metadata={ + "block_size": block_size, + "num_layers": num_layers, + "dtype": str(cache_dtype), + "seq_len": seq_len, + }, + ) + + def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str) -> None: + """Transfer KV cache data to downstream stage via OmniConnector. 
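+
+        The payload is stored under the connector key
+        "omni_{from_stage}_to_{to_stage}_kv_cache_{transfer_req_id}", which is the
+        key receive_kv_cache_for_request polls for on the receiving stage.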
+ + Args: + kv_data: The extracted KV cache data + transfer_req_id: The request ID to use for transfer + """ + from_stage, to_stage = self.send_stages + if not from_stage or not to_stage: + raise ValueError("Transfer stages (omni_from_stage, omni_to_stage) not configured") + + # Prepare data and transfer with retry + data_dict = kv_data.to_dict() + data_dict["request_id"] = transfer_req_id + + success, size, _ = self._transfer_with_retry(from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict) + + if success: + logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes") + else: + logger.error(f"KV transfer FAILED: {transfer_req_id}") + + def _transfer_with_retry( + self, + from_stage: str, + to_stage: str, + request_id: str, + data: dict[str, Any], + max_retries: int = 3, + ) -> tuple[bool, int, dict[str, Any] | None]: + """Transfer data with retry and exponential backoff. + + Args: + from_stage: Source stage identifier + to_stage: Target stage identifier + request_id: Request identifier for the key + data: Data to transfer + max_retries: Maximum number of retry attempts + + Returns: + Tuple of (success, size, metadata) + """ + for attempt in range(max_retries): + try: + # Build the full key for connector + full_request_id = f"omni_{from_stage}_to_{to_stage}_{request_id}" + success, size, metadata = self.connector.put( + from_stage=from_stage, to_stage=to_stage, put_key=full_request_id, data=data + ) + if success: + return success, size, metadata + logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") + except Exception as e: + logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") + + if attempt < max_retries - 1: + time.sleep(0.1 * (2**attempt)) + + return False, 0, None + + @torch.inference_mode() + def receive_kv_cache_for_request( + self, + request_id: str, + target_device: torch.device | None = None, + ) -> tuple[dict[str, Any] | None, int]: + """Receive KV cache for a specific request. + + This implements the receiving logic from gpu_diffusion_model_runner.py. 
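+
+        The connector is polled every 0.5 s until data arrives or
+        config.recv_timeout seconds elapse, so callers should expect this call to
+        block for up to the configured timeout.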
+ + Args: + request_id: The request ID to receive KV cache for + target_device: Optional device to move tensors to + + Returns: + Tuple of (data dict, size) if successful, (None, 0) otherwise + """ + if not self.connector: + logger.warning("No connector available for receiving KV cache") + return None, 0 + + from_stage, to_stage = self.recv_stages + if not from_stage or not to_stage: + logger.warning("Receive stages not configured") + return None, 0 + + # Check if we should receive KV cache based on config + if not self.config.need_recv_cache: + logger.info(f"Skip receiving KV cache for {request_id} (need_recv_cache=False)") + return None, 0 + + timeout = self.config.recv_timeout + start_time = time.time() + + logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...") + + try: + while True: + # Build the full key for connector + full_request_id = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}" + result = self.connector.get( + from_stage=from_stage, + to_stage=to_stage, + get_key=full_request_id, + ) + if result: + data, size = result + logger.info(f"Successfully received KV cache for {request_id}, {size} bytes") + + # Move tensors to target device if specified + if target_device is not None and isinstance(data, dict) and "layer_blocks" in data: + layer_blocks = data["layer_blocks"] + for cache_list in [ + layer_blocks.get("key_cache", []), + layer_blocks.get("value_cache", []), + ]: + for i, tensor in enumerate(cache_list): + if isinstance(tensor, torch.Tensor) and tensor.device != target_device: + cache_list[i] = tensor.to(target_device).contiguous() + + return data, size + + if time.time() - start_time > timeout: + logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s") + return None, 0 + + time.sleep(0.5) + + except Exception as e: + logger.error(f"Error receiving KV cache for {request_id}: {e}") + import traceback + + traceback.print_exc() + return None, 0 + + def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None: + """Apply received KV cache data to a request object. + + Args: + req: The request object to apply KV cache to + data: The received KV cache data dictionary + """ + if isinstance(data, dict) and "layer_blocks" in data: + layer_blocks = data["layer_blocks"] + from types import SimpleNamespace + + kv_obj = SimpleNamespace(**layer_blocks) + req.past_key_values = kv_obj + + # [Omni] Also attach to sampling_params for BagelPipeline compatibility + # BagelPipeline checks req.sampling_params.past_key_values + if hasattr(req, "sampling_params") and req.sampling_params is not None: + req.sampling_params.past_key_values = kv_obj + + if "metadata" in data: + req.kv_metadata = data["metadata"] + + # Legacy compatibility method + def receive_kv_cache(self, req: Any, target_device: torch.device | None = None) -> bool: + """Receive KV cache and populate request object (legacy interface). 
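+
+        The request id is taken from req.request_id, falling back to the first
+        entry of req.request_ids for OmniDiffusionRequest-style objects.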
+ + Args: + req: Request object with request_id attribute + target_device: Optional device to move tensors to + + Returns: + True if successful, False otherwise + """ + request_id = getattr(req, "request_id", None) + if not request_id and hasattr(req, "request_ids") and req.request_ids: + # Adaptation for new OmniDiffusionRequest which has list of prompts/ids + request_id = req.request_ids[0] + + if not request_id: + logger.warning("Request has no ID, cannot receive KV cache") + return False + + data, size = self.receive_kv_cache_for_request(request_id, target_device) + if data: + self.apply_kv_cache_to_request(req, data) + return True + return False diff --git a/vllm_omni/distributed/omni_connectors/utils/__init__.py b/vllm_omni/distributed/omni_connectors/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..208f01a7cb5ee04c88d276fec2082cd4e830884b --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/distributed/omni_connectors/utils/config.py b/vllm_omni/distributed/omni_connectors/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5556296d938dd69b39dfa41e08eb1e31d15ceb1f --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/config.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from typing import Any + +from .logging import get_connector_logger + +logger = get_connector_logger(__name__) + + +@dataclass +class ConnectorSpec: + """Specification for a connector instance.""" + + name: str # e.g., "MooncakeConnector", "SharedMemoryConnector", "YuanrongConnector" + extra: dict[str, Any] = field(default_factory=dict) # backend-specific config + + +@dataclass +class OmniTransferConfig: + """ + Top-level configuration for OmniConnector system. + Members: + connectors: A dictionary of connectors, keyed by (from_stage, to_stage). + default_connector: The default connector to use if no connector is specified for an edge. 
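+
+    Example (illustrative stage names and connector choice):
+        OmniTransferConfig(
+            connectors={("0", "1"): ConnectorSpec(name="SharedMemoryConnector")},
+            default_connector=ConnectorSpec(name="SharedMemoryConnector"),
+        )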
+ """ + + # Direct mapping: (from_stage, to_stage) -> connector + connectors: dict[tuple[str, str], ConnectorSpec] = field(default_factory=dict) + default_connector: ConnectorSpec | None = None + + def get_connector_for_edge(self, from_stage: str, to_stage: str) -> ConnectorSpec | None: + """Get connector spec for a specific edge.""" + edge_key = (from_stage, to_stage) + return self.connectors.get(edge_key, self.default_connector) + + def has_connector_for_edge(self, from_stage: str, to_stage: str) -> bool: + """Check if there's a connector configured for the edge.""" + return self.get_connector_for_edge(from_stage, to_stage) is not None diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..91322e63eac500529347c834fb123b29f8985692 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Utilities for OmniConnector configuration and validation.""" + +import json +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from ..factory import OmniConnectorFactory +from .config import ConnectorSpec, OmniTransferConfig +from .logging import get_connector_logger + +if TYPE_CHECKING: + from ..connectors.base import OmniConnectorBase +else: + OmniConnectorBase = Any + +logger = get_connector_logger(__name__) + + +def initialize_connectors_from_config( + config_path: str | Path | None = None, default_shm_threshold: int = 65536 +) -> tuple[OmniTransferConfig | None, dict[tuple[str, str], OmniConnectorBase]]: + """ + Initialize connectors from configuration file. + + Returns: + tuple: (OmniTransferConfig, dict of {(from, to): connector_instance}) + """ + transfer_config = load_omni_transfer_config(config_path, default_shm_threshold=default_shm_threshold) + + if not transfer_config: + logger.info("No OmniTransferConfig provided") + return None, {} + + # create connectors from config + connectors = create_connectors_from_config(transfer_config.connectors) + return transfer_config, connectors + + +def create_connectors_from_config( + connectors_config: dict[tuple[str, str], ConnectorSpec], +) -> dict[tuple[str, str], OmniConnectorBase]: + """ + Create connectors from config. + + Args: + connectors_config: A dictionary of connector configurations. + + Returns: + A dictionary of connectors. + """ + connectors = {} + for edge_key, connector_spec in connectors_config.items(): + try: + connector = OmniConnectorFactory.create_connector(connector_spec) + connectors[edge_key] = connector + logger.info(f"Created connector for {edge_key[0]} -> {edge_key[1]}: {type(connector).__name__}") + except Exception as e: + raise RuntimeError(f"Failed to initialize connector for edge {edge_key}: {e}") from e + + return connectors + + +def get_connectors_config_for_stage(transfer_config: OmniTransferConfig | None, stage_id: str | int) -> dict[str, Any]: + """ + Extract connector configurations relevant for a specific stage worker. + + Returns a dict compatible with worker initialization: + { + "from_stage_X": { + "spec": { + "name": "ConnectorName", + "extra": {...} + } + }, + ... 
+ } + """ + if not transfer_config: + return {} + + stage_connectors_config = {} + target_stage = str(stage_id) + + # Iterate through all configured edges + for (from_stage, to_stage), spec in transfer_config.connectors.items(): + # We only care about incoming edges for the worker process + # (Worker needs to create connectors to receive data) + if to_stage == target_stage: + stage_connectors_config[f"from_stage_{from_stage}"] = {"spec": {"name": spec.name, "extra": spec.extra}} + elif from_stage == target_stage and target_stage == "0": + stage_connectors_config[f"to_stage_{to_stage}"] = {"spec": {"name": spec.name, "extra": spec.extra}} + + return stage_connectors_config + + +def load_omni_transfer_config( + config_path: str | Path | None = None, + config_dict: dict[str, Any] | None = None, + default_shm_threshold: int = 65536, +) -> OmniTransferConfig | None: + """Load OmniTransferConfig from file or dict.""" + if config_path is None and config_dict is None: + # Even if no config provided, we might want to return a default config with SHM connectors + # But without stage info we can't do much. + return None + + if config_path is not None: + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path, encoding="utf-8") as f: + if config_path.suffix.lower() == ".json": + config_dict = json.load(f) + elif config_path.suffix.lower() in [".yaml", ".yml"]: + try: + import yaml + + config_dict = yaml.safe_load(f) + except ImportError: + raise ImportError("PyYAML required for YAML config files") + else: + raise ValueError(f"Unsupported config file format: {config_path.suffix}") + + if config_dict is None: + return None + + # Parse connectors + connectors = {} + runtime_config = config_dict.get("runtime", {}) + + # Parse global connectors (from runtime.connectors) + global_connectors = runtime_config.get("connectors", {}) + + # Parse stage-level connectors + stage_args = config_dict.get("stage_args", []) + expected_edges: set[tuple[str, str]] = set() + for stage_config in stage_args: + stage_id = str(stage_config["stage_id"]) + + # Input connectors + for input_key, conn_ref in stage_config.get("input_connectors", {}).items(): + if isinstance(conn_ref, str): + # Reference to global connector + if conn_ref in global_connectors: + conn_config = global_connectors[conn_ref] + connector = ConnectorSpec(name=conn_config["name"], extra=conn_config.get("extra", {})) + else: + raise ValueError(f"Undefined connector reference: {conn_ref}") + else: + # Inline connector definition + connector = ConnectorSpec(name=conn_ref["name"], extra=conn_ref.get("extra", {})) + + # Parse from_stage from key (e.g., "from_stage_0" -> "0") + from_stage = input_key.replace("from_stage_", "") + edge_key = (from_stage, stage_id) + connectors[edge_key] = connector + expected_edges.add(edge_key) + + # Output connectors + for output_key, conn_ref in stage_config.get("output_connectors", {}).items(): + if isinstance(conn_ref, str): + # Reference to global connector + if conn_ref in global_connectors: + conn_config = global_connectors[conn_ref] + connector = ConnectorSpec(name=conn_config["name"], extra=conn_config.get("extra", {})) + else: + raise ValueError(f"Undefined connector reference: {conn_ref}") + else: + # Inline connector definition + connector = ConnectorSpec(name=conn_ref["name"], extra=conn_ref.get("extra", {})) + + # Parse to_stage from key (e.g., "to_stage_1" -> "1") + to_stage = output_key.replace("to_stage_", "") + 
edge_key = (stage_id, to_stage) + connectors[edge_key] = connector + expected_edges.add(edge_key) + + # Auto-configure SharedMemoryConnector for missing edges based on runtime edges / engine_input_source + if stage_args: + try: + # Prefer explicit runtime edges if provided + runtime_edges = runtime_config.get("edges", []) + if isinstance(runtime_edges, list) and runtime_edges: + for edge in runtime_edges: + from_stage = edge.get("from") + to_stage = edge.get("to") + if from_stage is None or to_stage is None: + continue + edge_key = (str(from_stage), str(to_stage)) + expected_edges.add(edge_key) + if edge_key not in connectors: + logger.info(f"Auto-configuring SharedMemoryConnector for edge {edge_key}") + connectors[edge_key] = ConnectorSpec( + name="SharedMemoryConnector", + extra={"shm_threshold_bytes": default_shm_threshold}, + ) + + # Fallback: infer edges from engine_input_source for each stage + for stage_config in stage_args: + to_stage = str(stage_config["stage_id"]) + # Check explicit input sources + sources = stage_config.get("engine_input_source", []) + + for from_stage in sources: + from_stage_str = str(from_stage) + edge_key = (from_stage_str, to_stage) + expected_edges.add(edge_key) + + if edge_key not in connectors: + logger.info(f"Auto-configuring SharedMemoryConnector for edge {edge_key}") + connectors[edge_key] = ConnectorSpec( + name="SharedMemoryConnector", extra={"shm_threshold_bytes": default_shm_threshold} + ) + + except Exception as e: + logger.warning(f"Failed to auto-configure SHM connectors: {e}") + + # Fail fast if any expected edge is still missing a connector + missing_edges = [edge for edge in expected_edges if edge not in connectors] + if missing_edges: + missing_str = ", ".join([f"{f}->{t}" for f, t in missing_edges]) + raise ValueError( + "Connector configuration missing for edges: " + f"{missing_str}. Define connectors or allow auto SHM creation for these edges." + ) + + config = OmniTransferConfig(connectors=connectors) + + logger.info(f"Loaded OmniTransferConfig with {len(connectors)} connector configurations") + return config + + +# High-level management functions + + +def initialize_orchestrator_connectors( + config_path: str | None, worker_backend: str | None = "multi_process", shm_threshold_bytes: int = 65536 +) -> tuple[OmniTransferConfig | None, dict[tuple[str, str], OmniConnectorBase]]: + """Initialize connectors shared at orchestrator level. + Args: + config_path: The path to the configuration file. + worker_backend: The backend to use for the worker. + Returns: + A tuple containing the OmniTransferConfig and a dictionary of connectors. + """ + if worker_backend == "ray": + default_shm_threshold = sys.maxsize + else: + default_shm_threshold = max(0, shm_threshold_bytes) + transfer_config, connectors = initialize_connectors_from_config( + config_path, default_shm_threshold=default_shm_threshold + ) + return transfer_config, connectors + + +def get_stage_connector_config( + transfer_config: OmniTransferConfig | None, + stage_id: int, +) -> dict[str, Any]: + """Return the serialized connector config payload for a specific stage.""" + if transfer_config is None: + return {} + + try: + return get_connectors_config_for_stage(transfer_config, stage_id) + except Exception as exc: # pragma: no cover - defensive logging + logger.warning( + "Failed to build connector config for stage %s: %s. 
Using IPC fallback.", + stage_id, + exc, + ) + return {} + + +def build_stage_connectors( + stage_id: int, + connectors_config: dict[str, Any], +) -> dict[tuple[str, str], Any] | None: + """Instantiate OmniConnectors for a stage based on config.""" + if not connectors_config: + return {} + + logger.info( + "[Stage-%s] Initializing OmniConnectors with config keys: %s", + stage_id, + list(connectors_config.keys()), + ) + + from .config import ConnectorSpec + + connectors: dict[tuple[str, str], Any] = {} + # Convert dictionary-formatted config to ConnectorSpec objects + stage_connector_specs = {} + for input_key, config in connectors_config.items(): + if not input_key.startswith("from_stage_"): + continue + + from_stage = input_key.replace("from_stage_", "") + spec_dict = config.get("spec", {}) + if not spec_dict: + continue + + connector_spec = ConnectorSpec( + name=spec_dict.get("name", "SharedMemoryConnector"), + extra=spec_dict.get("extra", {}), + ) + stage_connector_specs[(str(from_stage), str(stage_id))] = connector_spec + + try: + # Use unified connector creation logic + connectors = create_connectors_from_config(stage_connector_specs) + except Exception as exc: # pragma: no cover - defensive logging + # Fail fast so the stage does not start with missing connectors. + logger.exception("[Stage-%s] Failed to initialize connectors: %s", stage_id, exc) + raise + + return connectors + + +def resolve_omni_kv_config_for_stage( + transfer_cfg: OmniTransferConfig | None, stage_id: int | str +) -> tuple[dict[str, Any] | None, str | None, str | None]: + """Resolve connector configuration for a specific stage (Sender/Receiver). + + This determines the primary connector configuration to be injected into the + engine arguments, prioritizing outgoing edges (Sender role). 
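+
+    Illustrative outcome, assuming a single edge "0" -> "1" backed by SharedMemoryConnector:
+        resolve_omni_kv_config_for_stage(cfg, 0)
+        # -> ({"type": "SharedMemoryConnector", ...}, "0", "1")  (stage 0 acts as the sender)
+        resolve_omni_kv_config_for_stage(cfg, 1)
+        # -> ({"type": "SharedMemoryConnector", ...}, "0", "1")  (stage 1 acts as the receiver)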
+ """ + if not transfer_cfg or not getattr(transfer_cfg, "connectors", None): + return None, None, None + + stage_id_str = str(stage_id) + + # Find outgoing edges (Sender logic) + outgoing = [ + (to_stage, spec) + for (from_stage, to_stage), spec in transfer_cfg.connectors.items() + if from_stage == stage_id_str + ] + + # Find incoming edges (Receiver logic) + incoming = [ + (from_stage, spec) + for (from_stage, to_stage), spec in transfer_cfg.connectors.items() + if to_stage == stage_id_str + ] + + omni_conn_cfg = None + omni_from = None + omni_to = None + + # Prioritize outgoing (Sender) if exists, else check incoming (Receiver) + if outgoing: + if len(outgoing) > 1: + logger.debug( + "Stage-%s has %d outgoing edges; using the smallest to_stage", + stage_id, + len(outgoing), + ) + outgoing.sort(key=lambda x: int(x[0]) if str(x[0]).isdigit() else str(x[0])) + to_s, spec = outgoing[0] + omni_conn_cfg = {"type": spec.name, **(spec.extra or {})} + omni_from = stage_id_str + omni_to = str(to_s) + elif incoming: + # For receiver, pick one incoming edge to configure the connector + incoming.sort(key=lambda x: int(x[0]) if str(x[0]).isdigit() else str(x[0])) + from_s, spec = incoming[0] + omni_conn_cfg = {"type": spec.name, **(spec.extra or {})} + omni_from = str(from_s) + omni_to = stage_id_str + + return omni_conn_cfg, omni_from, omni_to diff --git a/vllm_omni/distributed/omni_connectors/utils/logging.py b/vllm_omni/distributed/omni_connectors/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..a3410f53d36a18292207cf3cee9eb109109b07a4 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/logging.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging + +try: + from vllm.logger import init_logger as _vllm_init_logger +except Exception: # pragma: no cover - optional dependency + _vllm_init_logger = None + + +def get_connector_logger(name: str) -> logging.Logger: + """Return a logger preferring vLLM's init_logger when available.""" + return _vllm_init_logger(name) if _vllm_init_logger else logging.getLogger(name) diff --git a/vllm_omni/distributed/omni_connectors/utils/serialization.py b/vllm_omni/distributed/omni_connectors/utils/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..74aa6d4d08e69557263f315e5f6d9560b0bb5960 --- /dev/null +++ b/vllm_omni/distributed/omni_connectors/utils/serialization.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import asdict, is_dataclass +from typing import Any + +import msgspec +import numpy as np +import torch +from msgspec import msgpack +from PIL import Image +from vllm.outputs import CompletionOutput, RequestOutput + +# Type markers for custom serialization +_TENSOR_MARKER = "__tensor__" +_NDARRAY_MARKER = "__ndarray__" +_PIL_IMAGE_MARKER = "__pil_image__" + +# Keys that identify a RequestOutput dict (for reconstruction) +_REQUEST_OUTPUT_KEYS = frozenset({"request_id", "prompt", "prompt_token_ids", "outputs", "finished"}) + +# Keys that identify a CompletionOutput dict (for reconstruction) +_COMPLETION_OUTPUT_KEYS = frozenset({"index", "text", "token_ids", "finish_reason"}) + +# Keys that identify an OmniRequestOutput dict (for reconstruction) +# OmniRequestOutput has 'final_output_type' which is unique, or can be identified by +# having 'finished' and ('images' or 'final_output_type') 
+_OMNI_REQUEST_OUTPUT_KEYS = frozenset({"finished", "final_output_type"}) + + +class OmniMsgpackEncoder: + """ + This implementation is adapted from vLLM’s MsgpackEncoder. + However, zero-copy support has not been implemented yet. + Handles torch.Tensor, numpy.ndarray, PIL.Image, RequestOutput and + CompletionOutput by converting them to serializable dict representations. + TODO: Enable zero-copy support. + """ + + def __init__(self): + self.encoder = msgpack.Encoder(enc_hook=self._enc_hook) + + def encode(self, obj: Any) -> bytes: + """Encode an object to bytes.""" + return self.encoder.encode(obj) + + def _enc_hook(self, obj: Any) -> Any: + """Custom encoding hook for non-standard types.""" + # torch.Tensor + if isinstance(obj, torch.Tensor): + return self._encode_tensor(obj) + + # numpy.ndarray (exclude object/void dtypes) + if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"): + return self._encode_ndarray(obj) + + # PIL.Image + if isinstance(obj, Image.Image): + return self._encode_pil_image(obj) + + # RequestOutput (not a dataclass, needs special handling) + if isinstance(obj, RequestOutput): + return self._encode_request_output(obj) + + # CompletionOutput (dataclass) + if isinstance(obj, CompletionOutput): + return self._encode_completion_output(obj) + + # Other dataclasses + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + + # slice + if isinstance(obj, slice): + return (obj.start, obj.stop, obj.step) + + raise TypeError( + f"Object of type {type(obj).__name__} is not serializable. " + "Supported types: torch.Tensor, np.ndarray, PIL.Image, dataclass, " + "RequestOutput, and standard Python types (dict, list, str, int, float, bool, None, bytes)." + ) + + def _encode_tensor(self, tensor: torch.Tensor) -> dict[str, Any]: + """Encode torch.Tensor to dict.""" + t = tensor.detach().contiguous().cpu() + # Handle 0-dimensional (scalar) tensors by reshaping to 1D first + if t.dim() == 0: + t = t.reshape(1) + t = t.view(torch.uint8) + return { + _TENSOR_MARKER: True, + "dtype": str(tensor.dtype).removeprefix("torch."), + "shape": list(tensor.shape), + "data": t.numpy().tobytes(), + } + + def _encode_ndarray(self, arr: np.ndarray) -> dict[str, Any]: + """Encode numpy.ndarray to dict.""" + if not arr.flags.c_contiguous: + arr = np.ascontiguousarray(arr) + return { + _NDARRAY_MARKER: True, + "dtype": arr.dtype.str, + "shape": list(arr.shape), + "data": arr.tobytes(), + } + + def _encode_pil_image(self, img: Image.Image) -> dict[str, Any]: + """Encode PIL.Image to dict.""" + arr = np.asarray(img, dtype=np.uint8) + if not arr.flags.c_contiguous: + arr = np.ascontiguousarray(arr) + return { + _PIL_IMAGE_MARKER: True, + "mode": img.mode, + "shape": list(arr.shape), + "data": arr.tobytes(), + } + + def _encode_request_output(self, obj: RequestOutput) -> dict[str, Any]: + """Encode RequestOutput to dict. + + RequestOutput is not a dataclass, so we manually extract its attributes. + Also handles dynamically added 'multimodal_output' attribute. + """ + # msgspec can serialize CompletionOutput dataclasses directly, but it + # drops dynamic fields such as multimodal_output. Encode them manually + # to preserve multimodal payloads across IPC. 
+ encoded_outputs = [] + for o in obj.outputs: + if isinstance(o, CompletionOutput): + encoded_outputs.append(self._encode_completion_output(o)) + else: + encoded_outputs.append(o) + + result = { + "request_id": obj.request_id, + "prompt": obj.prompt, + "prompt_token_ids": obj.prompt_token_ids, + "prompt_logprobs": obj.prompt_logprobs, + "outputs": encoded_outputs, + "finished": obj.finished, + "metrics": obj.metrics, + "lora_request": obj.lora_request, + "encoder_prompt": obj.encoder_prompt, + "encoder_prompt_token_ids": obj.encoder_prompt_token_ids, + "num_cached_tokens": obj.num_cached_tokens, + "multi_modal_placeholders": obj.multi_modal_placeholders, + "kv_transfer_params": obj.kv_transfer_params, + } + # Handle dynamically added multimodal_output attribute + mm_output = getattr(obj, "multimodal_output", None) + if mm_output is not None: + result["multimodal_output"] = mm_output + return result + + def _encode_completion_output(self, obj: CompletionOutput) -> dict[str, Any]: + """Encode CompletionOutput to dict, preserving multimodal payloads.""" + result = asdict(obj) + mm_output = getattr(obj, "multimodal_output", None) + if mm_output is not None: + result["multimodal_output"] = mm_output + return result + + +class OmniMsgpackDecoder: + """ + This implementation is adapted from vLLM’s MsgpackDecoder. + However, zero-copy support has not been implemented yet. + + Automatically reconstructs torch.Tensor, numpy.ndarray, PIL.Image, + RequestOutput and CompletionOutput from their dict representations. + TODO: Enable zero-copy support. + """ + + def __init__(self): + self.decoder = msgpack.Decoder() + + def decode(self, data: bytes | bytearray | memoryview) -> Any: + """Decode bytes to object.""" + result = self.decoder.decode(data) + return self._post_process(result) + + def _post_process(self, obj: Any) -> Any: + """Recursively restore tensor/ndarray/image/RequestOutput/OmniRequestOutput from their dict representations.""" + if isinstance(obj, dict): + # Check for type markers first + if obj.get(_TENSOR_MARKER): + return self._decode_tensor(obj) + if obj.get(_NDARRAY_MARKER): + return self._decode_ndarray(obj) + if obj.get(_PIL_IMAGE_MARKER): + return self._decode_pil_image(obj) + + # Process values recursively first + processed = {k: self._post_process(v) for k, v in obj.items()} + + # Check if this looks like an OmniRequestOutput (check before RequestOutput + # since OmniRequestOutput may also have some RequestOutput-like fields) + if self._is_omni_request_output(processed): + return self._decode_omni_request_output(processed) + + # Check if this looks like a RequestOutput + if _REQUEST_OUTPUT_KEYS.issubset(processed.keys()): + return self._decode_request_output(processed) + + # Check if this looks like a CompletionOutput + if _COMPLETION_OUTPUT_KEYS.issubset(processed.keys()): + return self._decode_completion_output(processed) + + return processed + + if isinstance(obj, list): + return [self._post_process(item) for item in obj] + + if isinstance(obj, tuple): + return tuple(self._post_process(item) for item in obj) + + return obj + + def _is_omni_request_output(self, obj: dict[str, Any]) -> bool: + """Check if a dict looks like an OmniRequestOutput. 
+ + OmniRequestOutput can be identified by: + - Having 'finished' and 'final_output_type' fields (unique to OmniRequestOutput) + - OR having 'finished' and 'images' fields (diffusion mode) + """ + # Must have 'finished' field + if "finished" not in obj: + return False + + # Check for unique identifier: 'final_output_type' + if "final_output_type" in obj: + return True + + # Alternative: check for 'images' field (diffusion mode) + if "images" in obj: + return True + + return False + + def _decode_omni_request_output(self, obj: dict[str, Any]) -> Any: + """Decode dict to OmniRequestOutput. + + OmniRequestOutput is a dataclass, so we can use msgspec.convert + or construct it directly. + """ + from vllm_omni.outputs import OmniRequestOutput + + try: + # Use msgspec.convert for dataclass reconstruction + return msgspec.convert(obj, OmniRequestOutput) + except Exception: + try: + # Fallback: construct directly if msgspec.convert fails + # (e.g., if some fields are missing or have wrong types) + return OmniRequestOutput(**obj) + except Exception: + # If both attempts fail, return dict as-is (defensive fallback) + # This should rarely happen if _is_omni_request_output is correct + return obj + + def _decode_tensor(self, obj: dict[str, Any]) -> torch.Tensor: + """Decode dict to torch.Tensor.""" + dtype_str = obj["dtype"] + shape = obj["shape"] + data = obj["data"] + + torch_dtype = getattr(torch, dtype_str) + if not data: + return torch.empty(shape, dtype=torch_dtype) + + buffer = bytearray(data) if isinstance(data, (bytes, memoryview)) else data + arr = torch.frombuffer(buffer, dtype=torch.uint8) + return arr.view(torch_dtype).reshape(shape) + + def _decode_ndarray(self, obj: dict[str, Any]) -> np.ndarray: + """Decode dict to numpy.ndarray.""" + dtype = obj["dtype"] + shape = obj["shape"] + data = obj["data"] + return np.frombuffer(data, dtype=dtype).reshape(shape) + + def _decode_pil_image(self, obj: dict[str, Any]) -> Image.Image: + """Decode dict to PIL.Image.""" + mode = obj["mode"] + shape = obj["shape"] + data = obj["data"] + arr = np.frombuffer(data, dtype=np.uint8).reshape(shape) + return Image.fromarray(arr, mode=mode) + + def _decode_completion_output(self, obj: dict[str, Any]) -> CompletionOutput: + """Decode dict to CompletionOutput using msgspec.convert.""" + mm_output = obj.pop("multimodal_output", None) + co = msgspec.convert(obj, CompletionOutput) + if mm_output is not None: + setattr(co, "multimodal_output", mm_output) + return co + + def _decode_request_output(self, obj: dict[str, Any]) -> RequestOutput: + """Decode dict to RequestOutput. + + RequestOutput is not a dataclass, so msgspec.convert doesn't work. + We construct it manually, passing all known fields via **kwargs. 
+ """ + # Extract multimodal_output before constructing (it's dynamically added) + mm_output = obj.pop("multimodal_output", None) + + # RequestOutput.__init__ accepts **kwargs for forward compatibility + ro = RequestOutput(**obj) + + # Restore dynamically added multimodal_output attribute + if mm_output is not None: + setattr(ro, "multimodal_output", mm_output) + return ro + + +class OmniSerde: + """Serialization/deserialization handler for Omni IPC.""" + + def __init__(self): + self.encoder = OmniMsgpackEncoder() + self.decoder = OmniMsgpackDecoder() + + def serialize(self, obj: Any) -> bytes: + """Serialize an object to bytes.""" + return self.encoder.encode(obj) + + def deserialize(self, data: bytes | bytearray | memoryview) -> Any: + """Deserialize bytes to an object.""" + return self.decoder.decode(data) + + +# Global instance for simple interface +OmniSerializer = OmniSerde() diff --git a/vllm_omni/distributed/ray_utils/__init__.py b/vllm_omni/distributed/ray_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afdb0cec282869d57102b84a63b85dfdbbb31b2f --- /dev/null +++ b/vllm_omni/distributed/ray_utils/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .utils import calculate_total_bytes, is_ray_initialized, maybe_disable_pin_memory_for_ray + +__all__ = [ + "calculate_total_bytes", + "is_ray_initialized", + "maybe_disable_pin_memory_for_ray", +] diff --git a/vllm_omni/distributed/ray_utils/utils.py b/vllm_omni/distributed/ray_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07513b6601e8812d32e656f69e2b6c5824256aad --- /dev/null +++ b/vllm_omni/distributed/ray_utils/utils.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import os +from contextlib import contextmanager +from typing import Any + +import torch + +try: + import ray + from ray.util.queue import Queue as RayQueue + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + RAY_AVAILABLE = True + from ray.util.placement_group import PlacementGroup +except ImportError: + ray = None + RayQueue = None + PlacementGroupSchedulingStrategy = None + RAY_AVAILABLE = False + PlacementGroup = Any + +logger = logging.getLogger(__name__) + + +def is_ray_initialized(): + """Check if Ray is initialized without hard dependency on Ray.""" + # 1. Try standard API + if RAY_AVAILABLE: + if ray.is_initialized(): + return True + # 2. Fallback: Check environment variables typical for Ray Workers + # RAY_RAYLET_PID is always set in Ray workers + if "RAY_RAYLET_PID" in os.environ: + return True + return False + + +def calculate_total_bytes(size_args, dtype): + """ + Calculate total bytes for a tensor allocation, handling nested tuples in size args. + """ + num_elements = 1 + for s in size_args: + if isinstance(s, (tuple, list)): + for inner in s: + num_elements *= inner + else: + num_elements *= s + + element_size = torch.tensor([], dtype=dtype).element_size() + return num_elements * element_size + + +@contextmanager +def maybe_disable_pin_memory_for_ray(obj, size_bytes, threshold=32 * 1024 * 1024): + """ + Context manager to temporarily disable pin_memory if running in Ray and + the allocation size exceeds the threshold. 
+ + This is a workaround for Ray workers often having low ulimit -l (locked memory), + causing OS call failed errors when allocating large pinned buffers. + """ + should_disable = False + old_pin = False + + # Check 1: Are we in a Ray-like environment? + in_ray = is_ray_initialized() + + # Check 2: Is the size large enough to worry? + is_large = size_bytes > threshold + + # Check 3: Is pinning currently enabled? + is_pinned = getattr(obj, "pin_memory", False) + + if in_ray and is_large and is_pinned: + should_disable = True + old_pin = obj.pin_memory + obj.pin_memory = False + + try: + yield + finally: + if should_disable: + obj.pin_memory = old_pin + + +# --- Ray specific utilities --- + + +def get_ray_queue_class(): + if not RAY_AVAILABLE: + raise ImportError("ray is required for worker_backend='ray'") + return lambda: RayQueue(maxsize=0) + + +def initialize_ray_cluster(address: str | None = None): + if not RAY_AVAILABLE: + logger.warning("Ray is not available, skipping initialization.") + return + + if not ray.is_initialized(): + # Pass current PYTHONPATH to workers to ensure they can find vllm_omni + runtime_env = {"env_vars": {"PYTHONPATH": os.environ.get("PYTHONPATH", "")}} + ray.init(address=address, ignore_reinit_error=True, runtime_env=runtime_env) + + +def create_placement_group(number_of_stages: int, address: str | None = None, strategy: str = "PACK") -> PlacementGroup: + """Create a placement group for the given number of stages. + Args: + number_of_stages: The number of stages to create the placement group for. + strategy: The strategy to use for the placement group. + Returns: + The placement group. + """ + if not RAY_AVAILABLE: + raise ImportError("ray is required for creating placement group") + + # Initialize Ray if not already initialized (using default args if needed) + if not ray.is_initialized(): + logger.warning("[Orchestrator] Ray is not initialized. Initializing with default settings.") + initialize_ray_cluster(address) + + bundles = [{"GPU": 1.0, "CPU": 1.0} for _ in range(number_of_stages)] + pg = ray.util.placement_group(bundles, strategy=strategy) + ray.get(pg.ready()) + logger.info("[Orchestrator] Ray Placement Group created") + return pg + + +def remove_placement_group(pg): + if pg and RAY_AVAILABLE: + try: + ray.util.remove_placement_group(pg) + except Exception as e: + logger.warning(f"Failed to remove placement group: {e}") + + +def try_close_ray(pg=None): + """Try to clean up Ray resources including placement group and shutdown.""" + if pg: + remove_placement_group(pg) + # Note: We typically don't shutdown ray.init() here as it might be used by other components + # or the user might want it to persist. If full shutdown is needed, ray.shutdown() can be called. 
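+
+
+# Illustrative lifecycle of these helpers (a sketch, not a prescribed flow; stage_entry_fn is hypothetical):
+#   initialize_ray_cluster()
+#   pg = create_placement_group(number_of_stages=2)
+#   actor = start_ray_actor(stage_entry_fn, pg, placement_group_bundle_index=0)
+#   ...
+#   kill_ray_actor(actor)
+#   try_close_ray(pg)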
+
+
+def kill_ray_actor(actor):
+    if actor and RAY_AVAILABLE:
+        try:
+            ray.kill(actor)
+        except Exception as e:
+            logger.warning(f"Failed to kill actor: {e}")
+
+
+def start_ray_actor(
+    worker_entry_fn,
+    placement_group,
+    placement_group_bundle_index: int,
+    *args,
+    **kwargs,
+):
+    if not RAY_AVAILABLE:
+        raise ImportError("ray is required for starting ray actor")
+
+    @ray.remote(num_gpus=1)
+    class OmniStageRayWorker:
+        def run(self, func, *args, **kwargs):
+            return func(*args, **kwargs)
+
+    worker_actor = OmniStageRayWorker.options(
+        scheduling_strategy=PlacementGroupSchedulingStrategy(
+            placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_index
+        ),
+        runtime_env={"env_vars": {"PYTHONPATH": os.environ.get("PYTHONPATH", ""), "CUDA_LAUNCH_BLOCKING": "1"}},
+    ).remote()
+
+    worker_actor.run.remote(worker_entry_fn, *args, **kwargs)
+
+    return worker_actor
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0adfdb8d1930c098c1f2642acf7fbeb3b3ecf75a
--- /dev/null
+++ b/vllm_omni/engine/__init__.py
@@ -0,0 +1,81 @@
+"""
+Engine components for vLLM-Omni.
+"""
+
+from typing import Any
+
+import msgspec
+import torch
+from vllm.v1.engine import (
+    EngineCoreOutput,
+    EngineCoreOutputs,
+    EngineCoreRequest,
+)
+
+
+class PromptEmbedsPayload(msgspec.Struct):
+    """Serialized prompt embeddings payload for direct transfer.
+
+    data: raw bytes of the tensor in row-major order
+    shape: [seq_len, hidden_size]
+    dtype: torch dtype name (e.g., "float16", "float32")
+    """
+
+    data: bytes
+    shape: list[int]
+    dtype: str
+
+
+class AdditionalInformationEntry(msgspec.Struct):
+    """One entry of additional_information.
+
+    Two supported forms are encoded:
+    - tensor: data/shape/dtype
+    - list: a Python list (msgspec-serializable)
+    Exactly one of (tensor_data, list_data) should be non-None.
+    """
+
+    # Tensor form
+    tensor_data: bytes | None = None
+    tensor_shape: list[int] | None = None
+    tensor_dtype: str | None = None
+
+    # List form
+    list_data: list[Any] | None = None
+
+
+class AdditionalInformationPayload(msgspec.Struct):
+    """Serialized dictionary payload for additional_information.
+
+    Keys are strings; values are encoded as AdditionalInformationEntry.
+    """
+
+    entries: dict[str, AdditionalInformationEntry]
+
+
+class OmniEngineCoreRequest(EngineCoreRequest):
+    """Engine core request for omni models with embeddings support.
+
+    Extends the base EngineCoreRequest with support for prompt embeddings
+    and additional information payloads, enabling direct transfer of
+    pre-computed embeddings between pipeline stages.
+ + Attributes: + prompt_embeds: Optional serialized prompt embeddings payload for + direct transfer between stages + additional_information: Optional serialized additional information + dictionary containing tensors or lists to pass along with the request + """ + + # Optional prompt embeddings (direct-transfer version) + prompt_embeds: PromptEmbedsPayload | None = None + # Optional additional information dictionary (serialized) + additional_information: AdditionalInformationPayload | None = None + + +class OmniEngineCoreOutput(EngineCoreOutput): + pooling_output: dict[str, torch.Tensor] | None = None + + +class OmniEngineCoreOutputs(EngineCoreOutputs): + outputs: list[OmniEngineCoreOutput] = [] diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d550cf55548200843e56e1fe66e86dd2c3447df7 --- /dev/null +++ b/vllm_omni/engine/arg_utils.py @@ -0,0 +1,241 @@ +from dataclasses import dataclass, field +from typing import Any + +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeTextConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_hf_text_config +from vllm.v1.engine.async_llm import AsyncEngineArgs + +from vllm_omni.config import OmniModelConfig +from vllm_omni.plugins import load_omni_general_plugins + +logger = init_logger(__name__) + + +def _register_omni_hf_configs() -> None: + try: + from transformers import AutoConfig + + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import ( + Qwen3TTSConfig, + ) + except Exception as exc: # pragma: no cover - best-effort optional registration + logger.warning("Skipping omni HF config registration due to import error: %s", exc) + return + + try: + AutoConfig.register("qwen3_tts", Qwen3TTSConfig) + except ValueError: + # Already registered elsewhere; ignore. + return + + +def register_omni_models_to_vllm(): + from vllm.model_executor.models import ModelRegistry + + from vllm_omni.model_executor.models.registry import _OMNI_MODELS + + _register_omni_hf_configs() + + supported_archs = ModelRegistry.get_supported_archs() + for arch, (mod_folder, mod_relname, cls_name) in _OMNI_MODELS.items(): + if arch not in supported_archs: + ModelRegistry.register_model(arch, f"vllm_omni.model_executor.models.{mod_folder}.{mod_relname}:{cls_name}") + + +@dataclass +class OmniEngineArgs(EngineArgs): + """Engine arguments for omni models, extending base EngineArgs. + Adds omni-specific configuration fields for multi-stage pipeline + processing and output type specification. + Args: + stage_id: Identifier for the stage in a multi-stage pipeline (default: 0) + model_stage: Stage type identifier, e.g., "thinker" or "talker" + (default: "thinker") + model_arch: Model architecture name + (default: "Qwen2_5OmniForConditionalGeneration") + engine_output_type: Optional output type specification for the engine. + Used to route outputs to appropriate processors (e.g., "image", + "audio", "latents"). If None, output type is inferred. + custom_process_next_stage_input_func: Optional path to a custom function for processing + inputs from previous stages + If None, default processing is used. 
+ stage_connector_spec: Extra configuration for stage connector + async_chunk: If set to True, perform async chunk + """ + + stage_id: int = 0 + model_stage: str = "thinker" + model_arch: str = "Qwen2_5OmniForConditionalGeneration" + engine_output_type: str | None = None + hf_config_name: str | None = None + custom_process_next_stage_input_func: str | None = None + stage_connector_spec: dict[str, Any] = field(default_factory=dict) + async_chunk: bool = False + omni_kv_config: dict | None = None + + def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: + # transformers' get_text_config method is used to get the text config from thinker_config. + # to handle the case that each model stage has their own text config, + # we need to draw the text config from the corresponding model stage. + hf_config = config_dict["hf_config"] + hf_config_name = config_dict["hf_config_name"] + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(hf_config, hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(hf_config) + + def __post_init__(self) -> None: + load_omni_general_plugins() + super().__post_init__() + + def _ensure_omni_models_registered(self): + if hasattr(self, "_omni_models_registered"): + return True + register_omni_models_to_vllm() + self._omni_models_registered = True + return True + + def create_model_config(self) -> OmniModelConfig: + """Create an OmniModelConfig from these engine arguments. + Returns: + OmniModelConfig instance with all configuration fields set + """ + # register omni models to avoid model not found error + self._ensure_omni_models_registered() + + # First, get the base ModelConfig from the parent class + base_config = super().create_model_config() + + # Create OmniModelConfig by copying all base config attributes + # and adding the new omni-specific fields + config_dict = base_config.__dict__.copy() + # FIXME(Isotr0py): This is a temporary workaround for multimodal_config + config_dict = { + **(getattr(mm := config_dict.pop("multimodal_config", None), "__dict__", mm or {})), + **config_dict, + } + + # Add the new omni-specific fields + config_dict["stage_id"] = self.stage_id + config_dict["async_chunk"] = self.async_chunk + config_dict["model_stage"] = self.model_stage + config_dict["model_arch"] = self.model_arch + config_dict["engine_output_type"] = self.engine_output_type + # Build stage_connector_config from stage_connector_spec + stage_connector_config = { + "name": self.stage_connector_spec.get("name", "SharedMemoryConnector"), + "extra": self.stage_connector_spec.get("extra", {}).copy(), + } + stage_connector_config["extra"]["stage_id"] = self.stage_id + config_dict["stage_connector_config"] = stage_connector_config + + config_dict["hf_config_name"] = self.hf_config_name + config_dict["custom_process_next_stage_input_func"] = self.custom_process_next_stage_input_func + config_dict["omni_kv_config"] = self.omni_kv_config + if self.hf_config_name is not None: + config_dict["hf_text_config"] = self.draw_hf_text_config(config_dict) + # Create and return the OmniModelConfig instance + omni_config = OmniModelConfig(**config_dict) + omni_config.hf_config.architectures = omni_config.architectures + + return omni_config + + 
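+# Illustrative construction of OmniEngineArgs (hypothetical model path and stage wiring; a minimal sketch):
+#
+#   args = OmniEngineArgs(
+#       model="Qwen/Qwen2.5-Omni-7B",
+#       stage_id=1,
+#       model_stage="talker",
+#       hf_config_name="talker_config",
+#       stage_connector_spec={"name": "SharedMemoryConnector", "extra": {}},
+#   )
+#   model_config = args.create_model_config()  # OmniModelConfig with the omni-specific fields attached
+
+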
+@dataclass +class AsyncOmniEngineArgs(AsyncEngineArgs): + """Async engine arguments for omni models, extending base AsyncEngineArgs. + Adds omni-specific configuration fields for multi-stage pipeline + processing and output type specification in async contexts. + Args: + stage_id: Identifier for the stage in a multi-stage pipeline (default: 0) + model_stage: Stage type identifier, e.g., "thinker" or "talker" + (default: "thinker") + model_arch: Model architecture name + (default: "Qwen2_5OmniForConditionalGeneration") + engine_output_type: Optional output type specification for the engine. + Used to route outputs to appropriate processors (e.g., "image", + "audio", "latents"). If None, output type is inferred. + stage_connector_spec: Extra configuration for stage connector + """ + + stage_id: int = 0 + model_stage: str = "thinker" + model_arch: str = "Qwen2_5OmniForConditionalGeneration" + engine_output_type: str | None = None + hf_config_name: str | None = None + custom_process_next_stage_input_func: str | None = None + stage_connector_spec: dict[str, Any] = field(default_factory=dict) + async_chunk: bool = False + omni_kv_config: dict | None = None + + def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: + # transformers' get_text_config method is used to get the text config from thinker_config. + # to handle the case that each model stage has their own text config, + # we need to draw the text config from the corresponding model stage. + hf_config = config_dict["hf_config"] + hf_config_name = config_dict["hf_config_name"] + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(hf_config, hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(hf_config) + + def __post_init__(self) -> None: + load_omni_general_plugins() + super().__post_init__() + + def _ensure_omni_models_registered(self): + if hasattr(self, "_omni_models_registered"): + return True + register_omni_models_to_vllm() + self._omni_models_registered = True + return True + + def create_model_config(self) -> OmniModelConfig: + # register omni models to avoid model not found error + self._ensure_omni_models_registered() + # First, get the base ModelConfig from the parent class + base_config = super().create_model_config() + + # Create OmniModelConfig by copying all base config attributes + # and adding the new omni-specific fields + config_dict = base_config.__dict__.copy() + + # Add the new omni-specific fields + config_dict["stage_id"] = self.stage_id + config_dict["async_chunk"] = self.async_chunk + config_dict["model_stage"] = self.model_stage + config_dict["model_arch"] = self.model_arch + config_dict["engine_output_type"] = self.engine_output_type + stage_connector_config = { + "name": self.stage_connector_spec.get("name", "SharedMemoryConnector"), + "extra": self.stage_connector_spec.get("extra", {}).copy(), + } + stage_connector_config["extra"]["stage_id"] = self.stage_id + config_dict["stage_connector_config"] = stage_connector_config + + config_dict["hf_config_name"] = self.hf_config_name + config_dict["custom_process_next_stage_input_func"] = self.custom_process_next_stage_input_func + config_dict["omni_kv_config"] = self.omni_kv_config + if self.hf_config_name is not None: 
+ config_dict["hf_text_config"] = self.draw_hf_text_config(config_dict) + # Create and return the OmniModelConfig instance + omni_config = OmniModelConfig(**config_dict) + omni_config.hf_config.architectures = omni_config.architectures + + return omni_config diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e467a88ace3d59ece82f981ce325fc4866dc5a --- /dev/null +++ b/vllm_omni/engine/input_processor.py @@ -0,0 +1,296 @@ +import os +import time +from collections.abc import Mapping +from typing import Any, cast + +import numpy as np +import torch +from vllm.config import VllmConfig +from vllm.inputs import ProcessorInputs, PromptType +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict +from vllm.multimodal.processing.context import set_request_id +from vllm.multimodal.utils import argsort_mm_positions +from vllm.platforms import current_platform +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.v1.engine.input_processor import InputProcessor + +from vllm_omni.engine import ( + AdditionalInformationEntry, + AdditionalInformationPayload, + OmniEngineCoreRequest, + PromptEmbedsPayload, +) +from vllm_omni.inputs.preprocess import OmniInputPreprocessor +from vllm_omni.lora.request import LoRARequest + +logger = init_logger(__name__) + + +class OmniInputProcessor(InputProcessor): + """Processor for omni models, handling multimodal inputs and embeddings. + + Extends the base vLLM Processor with support for processing prompt + embeddings and additional information payloads, enabling direct transfer + of pre-computed embeddings between pipeline stages. + + Args: + vllm_config: Global vLLM configuration + mm_registry: Multi-modal registry for processing multimodal inputs + """ + + @staticmethod + def _dtype_to_name(dtype: torch.dtype) -> str: + """Convert torch dtype to string representation. 
+ + Args: + dtype: PyTorch dtype to convert + + Returns: + String representation of the dtype (e.g., "float32", "int64") + """ + mapping = { + torch.float32: "float32", + torch.float: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.bfloat16: "bfloat16", + torch.float64: "float64", + torch.double: "float64", + torch.int64: "int64", + torch.long: "int64", + torch.int32: "int32", + torch.int: "int32", + torch.int16: "int16", + torch.short: "int16", + torch.int8: "int8", + torch.uint8: "uint8", + torch.bool: "bool", + } + return mapping.get(dtype, str(dtype).replace("torch.", "")) + + def __init__( + self, + vllm_config: VllmConfig, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + super().__init__(vllm_config, mm_registry) + self.input_preprocessor = OmniInputPreprocessor( + self.model_config, + vllm_config.observability_config, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) + + def process_inputs( + self, + request_id: str, + prompt: PromptType, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, + priority: int = 0, + data_parallel_rank: int | None = None, + resumable: bool = False, + ) -> OmniEngineCoreRequest: + """Process input prompt into an engine core request. + + Converts a prompt (text, tokens, or multimodal) into an + OmniEngineCoreRequest that can be processed by the engine. + Handles prompt embeddings and additional information payloads + for direct transfer between stages. + + Args: + request_id: Unique identifier for this request + prompt: Input prompt (text, token IDs, embeddings, or multimodal) + params: Sampling or pooling parameters for generation + arrival_time: Optional arrival timestamp (defaults to current time) + lora_request: Optional LoRA adapter request + tokenization_kwargs: Optional additional tokenization arguments + trace_headers: Optional tracing headers for observability + priority: Request priority (higher values processed first) + data_parallel_rank: Optional data parallel rank for distributed + inference + + Returns: + Tuple of (prompt_string, OmniEngineCoreRequest) where: + - prompt_string: The original prompt as a string, or None if + using embeddings + - OmniEngineCoreRequest: Processed request ready for the engine + + Raises: + ValueError: If data_parallel_rank is out of range or prompt_embeds + has incorrect shape + """ + self._validate_lora(lora_request) + self._validate_params(params) + + parallel_config = self.vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + dp_local_size = parallel_config.data_parallel_size_local + num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size + if data_parallel_rank is not None and not (0 <= data_parallel_rank < num_ranks): + raise ValueError(f"data_parallel_rank {data_parallel_rank} is out of range [0, {num_ranks}).") + + if arrival_time is None: + arrival_time = time.time() + + # Optionally generate multimodal hash overrides to avoid hashing + # multimodal data items by their content as their identifiers. + + # NOTE: when users explicitly turn off BOTH prefix caching and input + # processing caching, no multimodal features or embeddings will be + # reused across requests, therefore identifying multimodal data items + # by their content is no longer necessary, and we create uuids with + # request id-modality-index as multimodal hash overrides. 
+ if ( + self.model_config.multimodal_config + and self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching + ): + mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) + else: + # Otherwise, use user-provided uuids as multimodal hash overrides + # if provided. + self._validate_mm_uuids(prompt) + if isinstance(prompt, dict): + mm_uuids = cast(MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")) + else: + mm_uuids = None + + # Process inputs, which includes: + # 1. Tokenize text prompt, with LoRA request if one exists. + # 2. For multimodal models with a merged preprocessor, preprocess + # multimodal data and expand prompt token ids accordingly. + num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) + if "OMP_NUM_THREADS" not in os.environ: + logger.debug_once( + "OMP_NUM_THREADS is not set; defaulting Torch threads to %d for input preprocessing.", + num_threads, + ) + + with set_request_id(request_id), set_default_torch_num_threads(num_threads): + processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( + prompt, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) + + eos_token_id = self.input_preprocessor.get_eos_token_id() + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + self._validate_model_inputs(encoder_inputs, decoder_inputs) + + # Normalize decoder prompt access across TypedDict variants. + if decoder_inputs["type"] == "embeds": + prompt_token_ids = None + prompt_embeds = decoder_inputs["prompt_embeds"] + else: + prompt_token_ids = decoder_inputs["prompt_token_ids"] + prompt_embeds = decoder_inputs.get("prompt_embeds") + + sampling_params = None + pooling_params = None + if isinstance(params, SamplingParams): + # TODO: can we avoid cloning here in multiproc case? + sampling_params = params.clone() + # If unset max tokens, then generate up to the max_model_len. + if sampling_params.max_tokens is None: + seq_len = length_from_prompt_token_ids_or_embeds(prompt_token_ids, prompt_embeds) + sampling_params.max_tokens = self.model_config.max_model_len - seq_len + sampling_params.update_from_generation_config(self.generation_config_fields, eos_token_id) + if self.tokenizer is not None: + sampling_params.update_from_tokenizer(self.tokenizer) + else: + pooling_params = params.clone() + + # Multimodal related. + mm_features: list[MultiModalFeatureSpec] | None = None + + if decoder_inputs["type"] == "multimodal": + decoder_mm_inputs = decoder_inputs["mm_kwargs"] + decoder_mm_positions = decoder_inputs["mm_placeholders"] + decoder_mm_hashes = decoder_inputs["mm_hashes"] + + # Merge and flatten multimodal placeholders, hashes and inputs + # from dictionaries to lists, and sort them by each item's position + # in the input sequence. + sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) + + mm_features = [] + for modality, idx in sorted_mm_idxs: + base_mm_hash = decoder_mm_hashes[modality][idx] + mm_features.append( + MultiModalFeatureSpec( + data=decoder_mm_inputs[modality][idx], + modality=modality, + identifier=self._get_mm_identifier(base_mm_hash, lora_request), + mm_position=decoder_mm_positions[modality][idx], + mm_hash=base_mm_hash, + ) + ) + + # Compatibility: decode serialized prompt embeds if provided. 
+        if isinstance(prompt_embeds, PromptEmbedsPayload):
+            prompt_embeds = self._decode_prompt_embeds(prompt_embeds)
+
+        additional_information_payload: AdditionalInformationPayload | None = None
+        raw_info: dict[str, Any] | AdditionalInformationPayload | None = decoder_inputs.get("additional_information")
+        if isinstance(raw_info, AdditionalInformationPayload):
+            additional_information_payload = raw_info
+        elif raw_info is not None:
+            entries: dict[str, AdditionalInformationEntry] = {}
+            for key, value in raw_info.items():
+                if isinstance(value, torch.Tensor):
+                    v_cpu = value.detach().to("cpu").contiguous()
+                    dtype_str = self._dtype_to_name(v_cpu.dtype)
+                    data_bytes = v_cpu.numpy().tobytes()
+                    entry = AdditionalInformationEntry(
+                        tensor_data=data_bytes,
+                        tensor_shape=[int(x) for x in list(v_cpu.shape)],
+                        tensor_dtype=dtype_str,
+                    )
+                elif isinstance(value, list):
+                    entry = AdditionalInformationEntry(list_data=value)
+                else:
+                    raise ValueError("additional_information values must be Tensor or list")
+                entries[key] = entry
+            additional_information_payload = AdditionalInformationPayload(entries=entries)
+
+        return OmniEngineCoreRequest(
+            request_id=request_id,
+            prompt_token_ids=prompt_token_ids,
+            mm_features=mm_features,
+            sampling_params=sampling_params,
+            pooling_params=pooling_params,
+            eos_token_id=eos_token_id,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            cache_salt=decoder_inputs.get("cache_salt"),
+            priority=priority,
+            data_parallel_rank=data_parallel_rank,
+            trace_headers=trace_headers,
+            prompt_embeds=prompt_embeds,
+            additional_information=additional_information_payload,
+            resumable=resumable,
+        )
+
+    @staticmethod
+    def _decode_prompt_embeds(payload: PromptEmbedsPayload) -> torch.Tensor:
+        dtype = getattr(np, payload.dtype)
+        arr = np.frombuffer(payload.data, dtype=dtype)
+        arr = arr.reshape(payload.shape)
+        return torch.from_numpy(arr)
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e39dc9fec2502a8f68120e638f7352a557b72057
--- /dev/null
+++ b/vllm_omni/engine/output_processor.py
@@ -0,0 +1,466 @@
+from typing import Dict
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+import torch
+from vllm.logger import init_logger
+from vllm.outputs import PoolingRequestOutput
+from vllm.sampling_params import RequestOutputKind
+from vllm.tokenizers import TokenizerLike
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
+from vllm.v1.engine.output_processor import OutputProcessor as VLLMOutputProcessor
+from vllm.v1.engine.output_processor import (
+    OutputProcessorOutput,
+    RequestOutputCollector,
+    RequestState,
+)
+from vllm.v1.engine.parallel_sampling import ParentRequest
+from vllm.v1.metrics.stats import IterationStats
+
+from vllm_omni.outputs import OmniRequestOutput
+
+logger = init_logger(__name__)
+
+
+class OmniRequestState(RequestState):
+    """Request state for omni models, tracking multimodal outputs.
+
+    Extends the base RequestState with support for accumulating
+    multimodal tensor outputs (e.g., images, audio, latents) that
+    are produced incrementally during generation.
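+
+    Illustrative accumulation (hypothetical image chunks):
+        state.add_multimodal_tensor({"image": torch.zeros(1, 3, 64, 64)}, "image")
+        state.add_multimodal_tensor({"image": torch.zeros(1, 3, 64, 64)}, "image")
+        # once the request finishes, the chunks are concatenated into one (2, 3, 64, 64) tensor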
+ """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.mm_type: str | None = None + self.mm_accumulated: Dict[str, Any] | None = None + + def add_multimodal_tensor(self, payload: Any | None, mm_type: str | None) -> None: + if payload is None: + return + try: + if mm_type: + self.mm_type = (mm_type or "").lower() + + # Normalize incoming payload to dict on CPU + def _to_cpu(x): + if isinstance(x, torch.Tensor): + try: + return x.detach().to("cpu", non_blocking=True).contiguous() + except Exception: + return x + return x + + if isinstance(payload, dict): + incoming: Dict[str, Any] = {} + target_key = self.mm_type or "hidden" + + # Iterate directly without unnecessary dict copy + for k, v in payload.items(): + # Optional remap: if producer used "model_outputs" or "hidden", rename to mm_type + if k == "model_outputs": + k = target_key + elif k == "hidden" and target_key != "hidden": + k = target_key + + if isinstance(v, dict): + incoming[k] = {str(sk): _to_cpu(sv) for sk, sv in v.items()} + else: + incoming[k] = _to_cpu(v) + else: + key = self.mm_type or "hidden" + incoming = {key: _to_cpu(payload)} + + if self.mm_accumulated is None: + self.mm_accumulated = incoming + else: + # Merge keys; accumulate tensors in lists for deferred concatenation + for k, v in incoming.items(): + if k not in self.mm_accumulated: + self.mm_accumulated[k] = v + else: + existing = self.mm_accumulated[k] + if isinstance(v, torch.Tensor) and isinstance(existing, torch.Tensor): + # Use list accumulation to avoid O(n²) repeated concatenation + self.mm_accumulated[k] = [existing, v] + elif isinstance(v, torch.Tensor) and isinstance(existing, list): + # Append to existing list + existing.append(v) + elif isinstance(v, dict) and isinstance(existing, dict): + # Merge nested dicts with list accumulation for tensors + for sk, sv in v.items(): + if sk not in existing: + existing[sk] = sv + elif isinstance(sv, torch.Tensor) and isinstance(existing[sk], torch.Tensor): + existing[sk] = [existing[sk], sv] + elif isinstance(sv, torch.Tensor) and isinstance(existing[sk], list): + existing[sk].append(sv) + else: + existing[sk] = sv + else: + self.mm_accumulated[k] = v + except Exception: + # Log and continue without crashing the output pipeline + logger.exception("Error accumulating multimodal tensor") + + def _consolidate_multimodal_tensors(self) -> None: + """Consolidate accumulated tensor lists into single tensors via concatenation.""" + if self.mm_accumulated is None: + return + try: + for k, v in self.mm_accumulated.items(): + if isinstance(v, list) and v and isinstance(v[0], torch.Tensor): + try: + if k == "audio": + # When the audio tensor shape is inconsistent, torch.cat will fail. + # We need to use torch.cat in -1 dimension. + continue + else: + self.mm_accumulated[k] = torch.cat(v, dim=0) + except Exception: + # Keep last tensor on failure + logger.warning(f"Error concatenating tensor for key {k}; keeping last tensor") + self.mm_accumulated[k] = v[-1] + elif isinstance(v, dict): + for sk, sv in v.items(): + if isinstance(sv, list) and sv and isinstance(sv[0], torch.Tensor): + try: + v[sk] = torch.cat(sv, dim=0) + except Exception: + v[sk] = sv[-1] + except Exception: + logger.exception("Error consolidating multimodal tensors") + + # Override: do not route to pooling-only path; always create completion + # outputs, and attach pooling_result into the CompletionOutput. 
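+    # The accumulated multimodal dict itself is surfaced as CompletionOutput.multimodal_output
+    # (see _new_completion_output below).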
+ def make_request_output( + self, + new_token_ids: list[int], + pooling_output: torch.Tensor | None, + finish_reason: FinishReason | None, + stop_reason: int | str | None, + kv_transfer_params: dict[str, Any] | None = None, + routed_experts: np.ndarray | None = None, + ) -> OmniRequestOutput | PoolingRequestOutput | None: + """Create a request output from generation results. + + Creates a RequestOutput or PoolingRequestOutput from the generated + tokens and accumulated multimodal outputs. Attaches multimodal + tensors to the completion output if available. + + Args: + new_token_ids: List of newly generated token IDs + pooling_output: Optional pooling output tensor + finish_reason: Optional finish reason indicating why generation stopped + stop_reason: Optional stop reason (token ID or stop string) + kv_transfer_params: Optional KV cache transfer parameters + + Returns: + OmniRequestOutput or PoolingRequestOutput if output should be + emitted (based on finish status and output kind), None otherwise + """ + # Pooling-only requests should follow base behavior. + if self.detokenizer is None and pooling_output is not None: + return super().make_request_output( + new_token_ids, + pooling_output, + finish_reason, + stop_reason, + kv_transfer_params, + routed_experts, + ) + + finished = finish_reason is not None + final_only = self.output_kind == RequestOutputKind.FINAL_ONLY + + if not finished and final_only: + return None + + # Consolidate accumulated tensors when finishing. + if finished: + self._consolidate_multimodal_tensors() + + if self.stream_interval > 1: + assert self.detokenizer is not None + + # Send output request only when + # 1. It has finished, or + # 2. It is the first token, or + # 3. It has reached the stream interval number of tokens + if not ( + finished + or self.sent_tokens_offset == 0 + or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset >= self.stream_interval + ): + return None + + if self.output_kind == RequestOutputKind.DELTA: + # Send tokens from the offset in DELTA mode, otherwise all + # tokens are sent. + new_token_ids = self.detokenizer.output_token_ids[self.sent_tokens_offset :] + self.sent_tokens_offset = len(self.detokenizer.output_token_ids) + + request_id = self.request_id + output = self._new_completion_output(new_token_ids, finish_reason, stop_reason, routed_experts) + + if self.parent_req is None: + outputs = [output] + else: + request_id, outputs, finished = self.parent_req.get_outputs(request_id, output) + if not outputs: + return None + + return self._new_request_output(request_id, outputs, finished, kv_transfer_params) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: FinishReason | None, + stop_reason: int | str | None, + routed_experts: np.ndarray | None = None, + ) -> Any: + # Reuse base text/logprobs logic, then annotate with pooling_result. 
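+ # Hypothetical consumer-side sketch (attribute name as attached below,
+ # everything else assumed):
+ #
+ #     completion = request_output.outputs[0]
+ #     mm = getattr(completion, "multimodal_output", {}) or {}
+ #     image = mm.get("image")        # torch.Tensor after consolidation
+ #     audio_chunks = mm.get("audio") # list[torch.Tensor], kept chunked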
+ base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts) + try: + if self.mm_accumulated is not None: + # Attach accumulated multimodal dict on the completion output + if not hasattr(base_output, "multimodal_output"): + setattr(base_output, "multimodal_output", {}) + mm_out = getattr(base_output, "multimodal_output") + if isinstance(mm_out, dict): + for k, v in self.mm_accumulated.items(): + mm_out[k] = v + else: + setattr(base_output, "multimodal_output", self.mm_accumulated) + except Exception: + logger.exception("Error in _new_completion_output") + return base_output + + +class MultimodalOutputProcessor(VLLMOutputProcessor): + """Handles multimodal output processing by normalizing EngineCoreOutput + before delegating to the base vLLM OutputProcessor. + + Strategy: + - Route by EngineCoreOutput.output_type when present + ("image", "text+image", "latents", "text"). + - Fallback to pooling/text heuristics when output_type is absent. + - Mutate EngineCoreOutput in-place to ensure vLLM's base processor can + produce the correct RequestOutput/PoolingRequestOutput. + - Allow custom per-modality handlers via register_handler(). + """ + + def __init__( + self, + tokenizer: TokenizerLike, + log_stats: bool, + engine_core_output_type: str | None = None, + ): + """Initialize the multimodal output processor. + + Args: + tokenizer: Tokenizer for detokenizing text outputs + log_stats: Whether to log statistics + engine_core_output_type: Optional output type specification + (e.g., "image", "audio", "latents"). Used to route outputs + to appropriate processors. If None, output type is inferred. + """ + super().__init__(tokenizer=tokenizer, log_stats=log_stats) + self.output_handlers: dict[str, Callable[[EngineCoreOutput], None]] = {} + self._reqid_to_mm_type: dict[str, str] = {} + self.engine_core_output_type = engine_core_output_type + + def register_handler(self, modality: str, handler: Callable[[EngineCoreOutput], None]) -> None: + """Register a custom handler for a specific modality. + + Allows custom processing logic for specific output modalities. + The handler is called before default processing for outputs + matching the specified modality. + + Args: + modality: Modality name (e.g., "image", "audio", "latents") + handler: Callable that takes an EngineCoreOutput and processes it + """ + self.output_handlers[modality.lower()] = handler + + def add_request( + self, + request: EngineCoreRequest, + prompt: str | None, + parent_req: ParentRequest | None = None, + request_index: int = 0, + queue: RequestOutputCollector | None = None, + ) -> None: + """Add a new request to be processed. + + Creates an OmniRequestState for the request and registers it + for output processing. 
+ + Args: + request: Engine core request to add + prompt: Optional prompt string for the request + parent_req: Optional parent request for parallel sampling + request_index: Index of the request in the batch + queue: Optional queue for collecting outputs + + Raises: + ValueError: If the request ID is already registered + """ + request_id = request.request_id + req_state = self.request_states.get(request_id) + if req_state is not None: + self._update_streaming_request_state(req_state, request, prompt) + return + + req_state = OmniRequestState.from_new_request( + tokenizer=self.tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats, + stream_interval=self.stream_interval, + ) + if self._requests_drained.is_set(): + self._requests_drained.clear() + self.request_states[request_id] = req_state + if parent_req: + self.parent_requests[parent_req.request_id] = parent_req + self.external_req_ids[req_state.external_req_id].append(request_id) + + def process_outputs( + self, + engine_core_outputs: list[EngineCoreOutput], + engine_core_timestamp: float | None = None, + iteration_stats: IterationStats | None = None, + ) -> OutputProcessorOutput: + self._reqid_to_mm_type.clear() + for eco in engine_core_outputs: + mm_type = (self.engine_core_output_type or "").lower() + if mm_type: + self._reqid_to_mm_type[eco.request_id] = mm_type + self._route_and_normalize(eco) + req_state = self.request_states.get(eco.request_id) + if req_state is None or not isinstance(req_state, OmniRequestState): + continue + if eco.pooling_output is not None and req_state.detokenizer is not None: + req_state.add_multimodal_tensor( + eco.pooling_output, + (getattr(eco, "output_type", self.engine_core_output_type) or "").lower(), + ) + # Force text path in base processor for multimodal outputs. + eco.pooling_output = None + + return super().process_outputs( + engine_core_outputs, + engine_core_timestamp=engine_core_timestamp, + iteration_stats=iteration_stats, + ) + + # ---- routing helpers ---- + def _route_and_normalize(self, eco: EngineCoreOutput) -> None: + output_type = (getattr(eco, "output_type", self.engine_core_output_type) or "").lower() + + # Custom handler first (if registered) + if output_type in self.output_handlers: + try: + self.output_handlers[output_type](eco) + # Fall through to default fixups in case the handler left gaps + except Exception: + logger.exception("Error in custom output handler for %s", output_type) + + if output_type == "image": + self._process_image_output(eco) + elif output_type in ("text+image", "text,image", "image+text"): + self._process_text_image_output(eco) + elif output_type in ("latents", "latent"): + self._process_latents_output(eco) + elif output_type in ("audio", "speech"): + self._process_audio_output(eco) + elif output_type == "text": + self._process_text_output(eco) + else: + # Fallback heuristic + if eco.pooling_output is not None: + self._process_pooling_output(eco) + else: + self._process_text_output(eco) + + # ---- modality processors ---- + def _process_image_output(self, eco: EngineCoreOutput) -> None: + """Ensure image tensors are surfaced via pooling_output for vLLM.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs(eco, keys=("image", "images", "pixel_values", "pixels")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_text_image_output(self, eco: EngineCoreOutput) -> None: + """Allow text+image outputs. 
Text path stays as new_token_ids; + image/latents route via pooling_output.""" + # Preserve text tokens as-is; ensure pooling_output carries image/latents + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, + keys=( + "image", + "images", + "pixel_values", + "pixels", + "latent", + "latents", + "z", + ), + ) + if tensor is not None: + eco.pooling_output = tensor + + def _process_latents_output(self, eco: EngineCoreOutput) -> None: + """Ensure latent tensors are surfaced via pooling_output.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs(eco, keys=("latent", "latents", "z", "posterior")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_audio_output(self, eco: EngineCoreOutput) -> None: + """Ensure audio tensors are surfaced via pooling_output.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, keys=("audio", "audios", "wav", "waveform", "audio_pcm", "pcm") + ) + if tensor is not None: + eco.pooling_output = tensor + + def _process_text_output(self, eco: EngineCoreOutput) -> None: + """No-op; base processor will detokenize new_token_ids → text.""" + return + + def _process_pooling_output(self, eco: EngineCoreOutput) -> None: + """Optional sanity checks for pooling tensor.""" + if eco.pooling_output is None: + return + if not isinstance(eco.pooling_output, torch.Tensor): + # Best-effort: convert to tensor if it's a list/ndarray-like + try: + eco.pooling_output = torch.as_tensor(eco.pooling_output) + except Exception: + pass + + def _extract_from_multimodal_outputs(self, eco: EngineCoreOutput, keys: tuple[str, ...]) -> torch.Tensor | None: + mm = getattr(eco, "multimodal_outputs", None) + if not isinstance(mm, dict): + return None + for k in keys: + v = mm.get(k) + if isinstance(v, torch.Tensor): + return v + # Try the first tensor in the dict as a fallback + for v in mm.values(): + if isinstance(v, torch.Tensor): + return v + return None diff --git a/vllm_omni/entrypoints/__init__.py b/vllm_omni/entrypoints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0ee51a5185bc02b33c97152f415b3e620e33b8 --- /dev/null +++ b/vllm_omni/entrypoints/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project + +""" +vLLM-Omni entrypoints module. 
+ +Provides high-level interfaces for running omni models including: +- AsyncOmni: Async orchestrator for multi-stage LLM pipelines +- AsyncOmniDiffusion: Async interface for diffusion model inference +- Omni: Unified entrypoint that auto-selects between LLM and Diffusion +""" + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion +from vllm_omni.entrypoints.omni import Omni + +__all__ = [ + "AsyncOmni", + "AsyncOmniDiffusion", + "Omni", +] diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..144c73f3a23d1befa3388d96705177eec25ee15a --- /dev/null +++ b/vllm_omni/entrypoints/async_omni.py @@ -0,0 +1,800 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import copy +import time +import weakref +from collections.abc import AsyncGenerator, Iterable, Sequence +from dataclasses import asdict +from pprint import pformat +from typing import Any + +from vllm.config import VllmConfig +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.plugins.io_processors import get_io_processor +from vllm.sampling_params import SamplingParams +from vllm.tokenizers import TokenizerLike +from vllm.v1.engine.exceptions import EngineDeadError + +from vllm_omni.config import OmniModelConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector +from vllm_omni.distributed.ray_utils.utils import try_close_ray +from vllm_omni.engine.input_processor import OmniInputProcessor +from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.log_utils import ( + OrchestratorMetrics, +) +from vllm_omni.entrypoints.omni import OmniBase +from vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType +from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load +from vllm_omni.entrypoints.utils import ( + get_final_stage_id_for_e2e, +) +from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams + +# Internal imports (our code) +from vllm_omni.lora.request import LoRARequest +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + + +def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler): + """Weak reference cleanup function for AsyncOmni instances.""" + if stage_list: + for q in stage_in_queues: + try: + q.put_nowait(SHUTDOWN_TASK) + except Exception as e: + logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + for stage in stage_list: + try: + stage.stop_stage_worker() + except Exception as e: + logger.warning(f"Failed to stop stage worker: {e}") + try_close_ray(ray_pg) + # Cancel output handler + if output_handler is not None: + output_handler.cancel() + + +class AsyncOmni(OmniBase): + """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models. + + Similar to the Omni class, but provides an asynchronous interface supporting + asynchronous LLM and Diffusion models. + + Args: + model: Model name or path to load. + **kwargs: Arbitrary keyword arguments. + - stage_configs_path: Optional path to YAML file containing stage + configurations. 
If None, configurations are loaded from the model.
+ - log_stats: Whether to enable statistics logging. When enabled, stats will
+ be written to files with stage-specific suffixes.
+ - stage_init_timeout: Per-stage init watchdog (seconds). Measured from
+ when the previous stage finished (possibly a prior Omni run with GPU
+ reuse/overlap) to when the current stage starts to initialize.
+ - shm_threshold_bytes: Threshold in bytes for using shared memory
+ for IPC. Objects larger than this threshold will use shared memory.
+ - worker_backend: Backend for worker processes. Default is "multi_process".
+ - ray_address: Address of the Ray cluster, if using the Ray backend.
+ - batch_timeout: Timeout in seconds for batching requests within a stage.
+ - init_timeout: Timeout in seconds for waiting for all stages to initialize.
+ - Any additional keyword arguments are passed through to stage engines.
+
+ Example:
+ >>> async_llm = AsyncOmni(model="Qwen/Qwen2.5-Omni-7B")
+ >>> async for output in async_llm.generate(
+ ... prompt="Hello",
+ ... request_id="req-1",
+ ... sampling_params_list=[SamplingParams(), SamplingParams()]
+ ... ):
+ ... print(output)
+ """
+
+ def __init__(self, model: str, **kwargs: Any) -> None:
+ # Pause/resume control attributes
+ self._pause_cond: asyncio.Condition = asyncio.Condition()
+ self._paused: bool = False
+
+ # Request state tracking
+ self.request_states: dict[str, ClientRequestState] = {}
+ self.output_handler: asyncio.Task | None = None
+
+ super().__init__(model, **kwargs)
+
+ # Register weak reference cleanup (called on garbage collection)
+ self._weak_finalizer = weakref.finalize(
+ self,
+ _weak_close_cleanup_async,
+ self.stage_list,
+ self._stage_in_queues,
+ self._ray_pg,
+ self.output_handler,
+ )
+
+ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> list[dict[str, Any]]:
+ """Create default diffusion stage configuration."""
+ # TODO: this differs from the Omni class; merge the two implementations in the future.
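+ # Sizing sketch (example values only): with tensor_parallel_size=2,
+ # ulysses_degree=2, ring_degree=1 and cfg_parallel_size=1, the derived
+ # sequence_parallel_size is 2 * 1 = 2, num_devices is 2 * 2 * 1 = 4, and the
+ # single diffusion stage (stage_id=0, final_output_type="image") is pinned
+ # to devices "0,1,2,3".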
+ cache_backend = kwargs.get("cache_backend", "none") + cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) + + devices = "0" + if "parallel_config" in kwargs: + parallel_config = kwargs["parallel_config"] + num_devices = kwargs["parallel_config"].world_size + for i in range(1, num_devices): + devices += f",{i}" + else: + ulysses_degree = kwargs.get("ulysses_degree") or 1 + ring_degree = kwargs.get("ring_degree") or 1 + sequence_parallel_size = kwargs.get("sequence_parallel_size") + tensor_parallel_size = kwargs.get("tensor_parallel_size") or 1 + cfg_parallel_size = kwargs.get("cfg_parallel_size") or 1 + if sequence_parallel_size is None: + sequence_parallel_size = ulysses_degree * ring_degree + num_devices = sequence_parallel_size * tensor_parallel_size * cfg_parallel_size + for i in range(1, num_devices): + devices += f",{i}" + parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=tensor_parallel_size, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + cfg_parallel_size=cfg_parallel_size, + ) + default_stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": { + "process": True, + "devices": devices, + "max_batch_size": 1, + }, + "engine_args": { + "parallel_config": parallel_config, + "vae_use_slicing": kwargs.get("vae_use_slicing", False), + "vae_use_tiling": kwargs.get("vae_use_tiling", False), + "cache_backend": cache_backend, + "cache_config": cache_config, + "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), + "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), + "layerwise_num_gpu_layers": kwargs.get("layerwise_num_gpu_layers", False), + "enforce_eager": kwargs.get("enforce_eager", False), + }, + "final_output": True, + "final_output_type": "image", + } + ] + default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" + return default_stage_cfg + + def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str, Any]) -> None: + # Store vllm_config received from worker process (may be None for diffusion stages) + vllm_config = result.get("vllm_config") + if vllm_config is not None: + stage.set_vllm_config(vllm_config) + tokenizer = result.get("tokenizer") + if tokenizer is not None: + stage.set_tokenizer(tokenizer) + is_tracing_enabled = result.get("is_tracing_enabled") + if is_tracing_enabled is not None: + stage.set_is_tracing_enabled(is_tracing_enabled) + super()._process_stage_ready(stage, stage_id, result) + + def _wait_for_stages_ready(self, timeout: int = 120) -> None: + """Wait for all stages to report readiness.""" + super()._wait_for_stages_ready(timeout) + for stage in self.stage_list: + if stage.vllm_config is not None and stage.tokenizer is not None: + try: + vllm_config = stage.vllm_config + # Initialize input_processor + # OMNI: OmniInputProcessor creates tokenizer internally from vllm_config + self.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + ) + # Initialize model_config + self.model_config = vllm_config.model_config + # Initialize io_processor + io_processor_plugin = self.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) + + logger.info( + f"[{self._name}] Initialized input_processor, " + f"io_processor, and model_config from stage-{stage.stage_id}", + ) + break + except 
Exception as e: + logger.warning( + f"[{self._name}] Failed to initialize processors from stage-{stage.stage_id}: {e}", + ) + # If no LLM stage found, set processors to None + if not hasattr(self, "input_processor") or self.input_processor is None: + logger.warning( + f"[{self._name}] No LLM stage found, processors will not be available. " + "This may cause issues with OpenAIServingModels." + ) + self.input_processor = None + self.io_processor = None + self.model_config = None + + def shutdown(self): + """Shutdown, cleaning up the background proc and IPC. + + Alias for close() method. Cleans up all stage processes + and inter-process communication resources. + """ + if hasattr(self, "_weak_finalizer"): + self._weak_finalizer() + + async def generate( + self, + prompt: OmniPromptType, + request_id: str, + sampling_params_list: Sequence[OmniSamplingParams] | None = None, + *, + output_modalities: list[str] | None = None, + ) -> AsyncGenerator[OmniRequestOutput, None]: + """Generate outputs for the given prompt asynchronously. + + Coordinates multi-stage pipeline through YAML configuration. + Each stage will use AsyncOmniLLM or AsyncOmniDiffusion based on stage_type. + Processes the prompt through all stages in the pipeline and yields + outputs as they become available. Each stage uses its corresponding + sampling parameters from the sampling_params_list. + + Args: + prompt: Prompt to process. Can be a text string, token IDs, + or multimodal prompt. + request_id: Unique identifier for this request + sampling_params_list: List of SamplingParams, one for each stage. + Must have the same length as the number of stages. + If None, uses default sampling params for each stage. + output_modalities: Optional list of output modalities. + + Yields: + OmniRequestOutput objects as they are produced by each stage. + Each output contains the stage_id, final_output_type, and + the request_output from that stage. + + Raises: + ValueError: If sampling_params_list has incorrect length. + """ + # Wait until generation is resumed if the engine is paused. 
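+ # (Requests submitted while paused simply block on this condition variable
+ # until resume_generation() calls notify_all(); see pause_generation() below.)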
+ async with self._pause_cond: + await self._pause_cond.wait_for(lambda: not self._paused) + + logger.debug(f"[{self._name}] generate() called") + try: + # Start output handler on the first call to generate() + self._run_output_handler() + + # TODO: lora_request, trace_headers, priority are not supported yet + if sampling_params_list is None: + sampling_params_list = self.default_sampling_params_list + + if len(sampling_params_list) != len(self.stage_list): + raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + + # Orchestrator keeps stage objects for input derivation + num_stages = len(self.stage_list) + # Track per-request start time for end-to-end timing + _req_start_ts: dict[int, float] = {} + _wall_start_ts: float = time.time() + # _last_finish_ts: float = _wall_start_ts + + # Determine the final stage for E2E stats (highest stage_id with + # final_output=True; fallback to last stage) + final_stage_id_for_e2e = get_final_stage_id_for_e2e( + output_modalities, self.output_modalities, self.stage_list + ) + + # Metrics/aggregation helper + metrics = OrchestratorMetrics( + num_stages, + self._enable_stats, + _wall_start_ts, + ) + req_state = ClientRequestState(request_id) + req_state.metrics = metrics + self.request_states[request_id] = req_state + sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] + task = { + "request_id": request_id, + "engine_inputs": prompt, + "sampling_params": sp0, + } + self.stage_list[0].submit(task) + metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() + _req_start_ts[request_id] = time.time() + logger.info( + f"[{self._name}] Entering scheduling loop: stages={num_stages}, final_stage={final_stage_id_for_e2e}" + ) + if self.async_chunk: + stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)} + req_state.stage_queues = stage_queues + async for output in self._process_async_results( + request_id, + prompt, + sampling_params_list, + req_state, + metrics, + final_stage_id_for_e2e, + _req_start_ts, + _wall_start_ts, + ): + yield output + else: + async for output in self._process_sequential_results( + request_id, + req_state, + metrics, + final_stage_id_for_e2e, + _req_start_ts, + _wall_start_ts, + sampling_params_list, + prompt, + ): + yield output + + logger.debug(f"[{self._name}] All requests completed") + + # Summarize and print stats + try: + summary = metrics.build_and_log_summary(final_stage_id_for_e2e) + logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + except Exception as e: + logger.exception(f"[{self._name}] Failed to build/log summary: {e}") + finally: + self.request_states.pop(request_id, None) + except (asyncio.CancelledError, GeneratorExit): + await self.abort(request_id) + logger.info("[AsyncOrchestrator] Request %s aborted.", request_id) + raise + + async def _process_async_results( + self, + request_id: str, + prompt: Any, + sampling_params_list: list[SamplingParams], + req_state: ClientRequestState, + metrics: OrchestratorMetrics, + final_stage_id_for_e2e: int, + req_start_ts: dict[int, float], + wall_start_ts: float, + ) -> AsyncGenerator[OmniRequestOutput, None]: + all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)} + submit_flag = True + while not all(all_stages_finished.values()): + for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + if all_stages_finished[stage_id]: + continue + try: + result = req_state.stage_queues[stage_id].get_nowait() + except 
asyncio.QueueEmpty: + await asyncio.sleep(0.001) + continue + + engine_outputs, finished, output_to_yield = self._process_single_result( + result, stage, stage_id, metrics, req_start_ts, wall_start_ts, final_stage_id_for_e2e + ) + if submit_flag and stage_id == 0: + submit_flag = False + prompt_token_ids = engine_outputs.prompt_token_ids + engine_input = copy.deepcopy(prompt) + engine_input["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids) + engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None + for i in range(1, len(self.stage_list)): + task = { + "request_id": request_id, + "engine_inputs": engine_input, + "sampling_params": sampling_params_list[i], + } + self.stage_list[i].submit(task) + metrics.stage_first_ts[i] = time.time() + all_stages_finished[stage_id] = finished + + if output_to_yield: + yield output_to_yield + + async def _process_sequential_results( + self, + request_id: str, + req_state: ClientRequestState, + metrics: OrchestratorMetrics, + final_stage_id_for_e2e: int, + req_start_ts: dict[int, float], + wall_start_ts: float, + sampling_params_list: list[SamplingParams], + prompt: Any, + ) -> AsyncGenerator[OmniRequestOutput, None]: + for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + finished = False + while not finished: + result = await req_state.queue.get() + assert stage_id == req_state.stage_id + req_id = result.get("request_id") + engine_outputs, finished, output_to_yield = self._process_single_result( + result, stage, stage_id, metrics, req_start_ts, wall_start_ts, final_stage_id_for_e2e + ) + if output_to_yield: + yield output_to_yield + if not isinstance(engine_outputs, list): + engine_outputs = [engine_outputs] + stage.set_engine_outputs(engine_outputs) + # Forward to next stage if there is one + next_stage_id = stage_id + 1 + if next_stage_id <= final_stage_id_for_e2e and finished: + next_stage: OmniStage = self.stage_list[next_stage_id] + next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) + sp_next: SamplingParams = sampling_params_list[next_stage_id] + + # Check if we have a connector for this edge + connector_key = (str(stage_id), str(next_stage_id)) + connector = self.connectors.get(connector_key) + + sent_via_connector = False + if connector: + sent_via_connector = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=next_inputs, + sampling_params=sp_next, + original_prompt=prompt, + next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + metrics=metrics, + ) + + if not sent_via_connector: + # Fallback logic removed as we now enforce connector usage. + # If no connector is found or send fails, we log an error and raise, + # because continuing would cause the request to be silently dropped + # and the orchestrator to hang waiting for completion. + error_msg = ( + f"[{self._name}] Failed to send request {req_id} to stage-{next_stage_id} via connector. " + "Configure a connector for this edge or inspect connector logs for details." 
+ ) + logger.error(error_msg) + raise RuntimeError(error_msg) + logger.debug(f"[{self._name}] Forwarded request {req_id} to stage-{next_stage_id}") + else: + logger.debug(f"[{self._name}] Request {req_id} fully completed") + + def _process_single_result( + self, + result: dict[str, Any], + stage: OmniStage, + stage_id: int, + metrics: OrchestratorMetrics, + req_start_ts: dict[int, float], + wall_start_ts: float, + final_stage_id_for_e2e: int, + ) -> tuple[Any, bool, OmniRequestOutput | None]: + """ + Process a single result dictionary from a stage. + Returns: + engine_outputs: The decoded outputs. + finished: Whether the stage processing is finished for this request. + output_to_yield: An OmniRequestOutput to yield, or None. + """ + req_id = result.get("request_id") + if "error" in result: + logger.error( + f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", + ) + raise RuntimeError(result) + + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + if isinstance(engine_outputs, list): + engine_outputs = engine_outputs[0] + + finished = engine_outputs.finished + + # Mark last output time + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + + try: + _m = asdict(result.get("metrics")) + if _m is not None and finished: + metrics.on_stage_metrics(stage_id, req_id, _m) + except Exception as e: + logger.exception( + f"[{self._name}] Failed to process metrics for stage {stage_id}, req {req_id}: {e}", + ) + + logger.debug( + f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", + ) + + output_to_yield = None + if getattr(stage, "final_output", False): + logger.debug(f"[{self._name}] Request {req_id} finalized at stage-{stage_id}") + + # Finalize request metrics if this is the E2E final stage and it's finished + try: + rid_key = str(req_id) + if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done and finished: + metrics.on_finalize_request( + stage_id, + req_id, + req_start_ts.get(req_id, wall_start_ts), + ) + except Exception as e: + logger.exception( + f"[{self._name}] Finalize request handling error for req {req_id} at stage {stage_id}: {e}", + ) + + # Construct output to yield + images = [] + if stage.final_output_type == "image": + if isinstance(engine_outputs, OmniRequestOutput) and engine_outputs.images: + images = engine_outputs.images + elif hasattr(engine_outputs, "images") and engine_outputs.images: + images = engine_outputs.images + + if stage.final_output_type == "image": + output_to_yield = OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + images=images, + ) + else: + output_to_yield = OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + ) + + return engine_outputs, finished, output_to_yield + + def _run_output_handler(self) -> None: + if self.output_handler is not None: + return + + stage_list = self.stage_list + request_states = self.request_states + + async def output_handler(): + try: + while True: + idle = True + for stage_id, stage in enumerate(stage_list): + result = stage.try_collect() + if result is None: + continue + idle = False + if result.get("type") == "stage_ready": + # Only happens when stage is initialized slower than expected, + # so we wait for a short time and try again + await asyncio.sleep(0.05) + continue + req_id = result.get("request_id") + req_state = 
request_states.get(req_id) + if req_state is None: + logger.debug( + f"[{self._name}] Request may have been aborted; \ + dropping output for req {req_id} at stage-{stage_id}" + ) + continue + if hasattr(req_state, "stage_queues") and stage_id in req_state.stage_queues: + await req_state.stage_queues[stage_id].put(result) + else: + # Fallback to old behavior for compatibility + await req_state.queue.put(result) + req_state.stage_id = stage_id + if idle: + await asyncio.sleep(0.001) # Avoid CPU overload when idle + else: + await asyncio.sleep(0) + except Exception as e: + logger.exception("AsyncOmni output_handler failed.") + for req_state in request_states.values(): + error_msg = {"request_id": req_state.request_id, "error": str(e)} + # Send error to all stage queues + if hasattr(req_state, "stage_queues"): + for queue in req_state.stage_queues.values(): + await queue.put(error_msg) + else: + await req_state.queue.put(error_msg) + error_msg = {"request_id": req_state.request_id, "error": str(e)} + self.output_handler = None # Make possible for restart + + self.output_handler = asyncio.create_task(output_handler()) + + @property + def is_running(self) -> bool: + # Is None before the loop is started. + return len(self._stage_in_queues) > 0 + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return not self.is_running + + @property + def _name(self) -> str: + return "AsyncOrchestrator" + + @property + def is_async(self) -> bool: + return True + + @property + def dead_error(self) -> BaseException: + return EngineDeadError() + + async def abort(self, request_id: str | Iterable[str]) -> None: + abort_task = {"type": OmniStageTaskType.ABORT, "request_id": request_id} + for stage in self.stage_list: + stage.submit(abort_task) + return None + + async def get_vllm_config(self) -> VllmConfig: + for stage in self.stage_list: + if stage.is_comprehension: + # Use the vllm_config received from worker process + if stage.vllm_config is not None: + return stage.vllm_config + return None + + async def get_model_config(self) -> OmniModelConfig: + for stage in self.stage_list: + if stage.is_comprehension: + # Use the vllm_config received from worker process + if stage.vllm_config is not None: + return stage.vllm_config.model_config + return None + + async def get_input_preprocessor(self) -> InputPreprocessor: + return None + + async def get_tokenizer(self) -> TokenizerLike: + for stage in self.stage_list: + if stage.is_comprehension: + return stage.tokenizer + return None + + async def is_tracing_enabled(self) -> bool: + for stage in self.stage_list: + if stage.is_comprehension: + return stage.is_tracing_enabled + return False + + @property + def renderer(self): + """Return the renderer from input_processor if available. + + OMNI: Required by upstream OpenAIServingModels.__init__ which + accesses engine_client.renderer. 
+ """ + return self.input_processor.renderer + + async def do_log_stats(self) -> None: + pass + + async def check_health(self) -> None: + pass + + async def reset_mm_cache(self) -> None: + pass + + async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + pass + + async def sleep(self, level: int = 1) -> None: + pass + + async def wake_up(self, tags: list[str] | None = None) -> None: + pass + + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + return False + + async def add_lora(self, lora_request: LoRARequest) -> bool: + """Load a new LoRA adapter into the engine for future requests.""" + return False + + async def encode( + self, + *args, + **kwargs, + ): + """Generate outputs for a request from a pooling model.""" + raise NotImplementedError("encode() is not implemented for AsyncOmni") + + async def start_profile(self, stages: list[int] | None = None) -> None: + """Start profiling for specified stages. + + Async wrapper around the base implementation for API consistency. + + Args: + stages: List of stage IDs to start profiling. If None, starts + profiling for all stages that have profiling enabled. + + Example: + >>> await async_omni.start_profile() + >>> async for output in async_omni.generate(...): + ... pass + >>> await async_omni.stop_profile() + """ + super().start_profile(stages) + + async def stop_profile(self, stages: list[int] | None = None) -> None: + """Stop profiling for specified stages. + + Async wrapper around the base implementation for API consistency. + + Args: + stages: List of stage IDs to stop profiling. If None, stops + profiling for all stages. + + Example: + >>> await async_omni.start_profile() + >>> async for output in async_omni.generate(...): + ... pass + >>> await async_omni.stop_profile() + """ + super().stop_profile(stages) + + async def pause_generation( + self, + *, + wait_for_inflight_requests: bool = False, + clear_cache: bool = True, + ) -> None: + """ + Pause generation to allow model weight updates. + + New generation/encoding requests are blocked until resume. + + Args: + wait_for_inflight_requests: When ``True`` waits for in-flight + requests to finish before pausing. When ``False`` (default), + immediately aborts any in-flight requests. + clear_cache: Whether to clear KV cache and prefix cache after + draining. Set to ``False`` to preserve cache for faster resume. + Default is ``True`` (clear caches). + """ + + async with self._pause_cond: + if self._paused: + return + self._paused = True + + # Note: AsyncOmni uses a stage-based architecture without a central + # output_processor. For now, we simply set the pause flag and let + # new requests wait. In-flight requests will complete naturally. + # TODO: Implement request abortion for stages if needed. 
+ + # Clear cache if requested + if clear_cache: + await self.reset_prefix_cache() + await self.reset_mm_cache() + + async def resume_generation(self) -> None: + """Resume generation after :meth:`pause_generation`.""" + + async with self._pause_cond: + self._paused = False + self._pause_cond.notify_all() # Wake up all waiting requests + + async def is_paused(self) -> bool: + """Return whether the engine is currently paused.""" + + async with self._pause_cond: + return self._paused diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..535f04f7d2e507d8dc7288a34ae30d353d175cbb --- /dev/null +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Async entrypoint for vLLM-Omni diffusion model inference. + +Provides an asynchronous interface for running diffusion models, +enabling concurrent request handling and streaming generation. +""" + +import asyncio +import uuid +from collections.abc import AsyncGenerator, Iterable +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_hf_file_to_dict + +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.lora.request import LoRARequest +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + + +class AsyncOmniDiffusion: + """Async entry point for vLLM-Omni diffusion model inference. + + This class provides an asynchronous interface for running diffusion models, + enabling concurrent request handling. It wraps the DiffusionEngine and + provides async methods for image generation. + + Args: + model: Model name or path to load + od_config: Optional OmniDiffusionConfig. If not provided, it will be + created from kwargs + **kwargs: Additional keyword arguments passed to OmniDiffusionConfig + + Example: + >>> async_diffusion = AsyncOmniDiffusion(model="Qwen/Qwen-Image") + >>> result = await async_diffusion.generate( + ... prompt="A beautiful sunset over the ocean", + ... request_id="req-1", + ... 
) + >>> print(result.images) + """ + + def __init__( + self, + model: str, + od_config: OmniDiffusionConfig | None = None, + **kwargs: Any, + ): + self.model = model + + # Capture stage info from kwargs before they might be filtered out + stage_id = kwargs.get("stage_id") + engine_input_source = kwargs.get("engine_input_source") + + # Build config + if od_config is None: + od_config = OmniDiffusionConfig.from_kwargs(model=model, **kwargs) + elif isinstance(od_config, dict): + # If config is dict, check it too (priority to kwargs if both exist) + if stage_id is None: + stage_id = od_config.get("stage_id") + if engine_input_source is None: + engine_input_source = od_config.get("engine_input_source") + od_config = OmniDiffusionConfig.from_kwargs(**od_config) + + self.od_config = od_config + + # Inject stage info into omni_kv_config if present + if stage_id is not None: + self.od_config.omni_kv_config.setdefault("stage_id", stage_id) + if engine_input_source is not None: + self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) + + try: + config_dict = get_hf_file_to_dict("model_index.json", od_config.model) + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() + + tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + except (AttributeError, OSError, ValueError): + cfg = get_hf_file_to_dict("config.json", od_config.model) + if cfg is None: + raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") + + model_type = cfg.get("model_type") + architectures = cfg.get("architectures") or [] + if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: + od_config.model_class_name = "BagelPipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() + + # Initialize engine + self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config) + + # Thread pool for running sync engine in async context + self._executor = ThreadPoolExecutor(max_workers=1) + self._closed = False + + logger.info("AsyncOmniDiffusion initialized with model: %s", model) + + async def generate( + self, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, + request_id: str | None = None, + lora_request: LoRARequest | None = None, + ) -> OmniRequestOutput: + """Generate images asynchronously from a text prompt. 
+ + Args: + prompt: Text prompt describing the desired image + sampling_params: Sampling parameters + request_id: Optional unique identifier for tracking the request + + Returns: + OmniRequestOutput containing generated images + + Raises: + RuntimeError: If generation fails + """ + if request_id is None: + request_id = f"diff-{uuid.uuid4().hex[:16]}" + + if sampling_params.guidance_scale: + sampling_params.guidance_scale_provided = True + + if lora_request is not None: + sampling_params.lora_request = lora_request + + request = OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_ids=[request_id], + ) + + logger.debug("Starting generation for request %s", request_id) + + # Run engine in thread pool + loop = asyncio.get_event_loop() + try: + # In async mode, only a single request is submitted at a time + result = await loop.run_in_executor( + self._executor, + self.engine.step, + request, + ) + result = result[0] + except Exception as e: + logger.error("Generation failed for request %s: %s", request_id, e) + raise RuntimeError(f"Diffusion generation failed: {e}") from e + + # Update request_id if needed + if not result.request_id: + result.request_id = request_id + return result + + async def generate_stream( + self, + prompt: str, + request_id: str | None = None, + **kwargs: Any, + ) -> AsyncGenerator[OmniRequestOutput, None]: + """Generate images with streaming progress updates. + + Currently, diffusion models don't support true streaming, so this + yields a single result after generation completes. Future implementations + may support step-by-step progress updates. + + Args: + prompt: Text prompt describing the desired image + request_id: Optional unique identifier for tracking the request + **kwargs: Additional generation parameters + + Yields: + OmniRequestOutput with generation progress/results + """ + result = await self.generate(prompt=prompt, request_id=request_id, **kwargs) + yield result + + def close(self) -> None: + """Close the engine and release resources. + + Should be called when done using the AsyncOmniDiffusion instance. 
+ """ + if self._closed: + return + self._closed = True + + try: + self.engine.close() + except Exception as e: + logger.warning("Error closing diffusion engine: %s", e) + + try: + self._executor.shutdown(wait=False) + except Exception as e: + logger.warning("Error shutting down executor: %s", e) + + logger.info("AsyncOmniDiffusion closed") + + def shutdown(self) -> None: + """Alias for close() method.""" + self.close() + + def __del__(self) -> None: + """Best-effort cleanup on deletion.""" + try: + self.close() + except Exception: + pass + + async def abort(self, request_id: str | Iterable[str]) -> None: + """Abort a request.""" + self.engine.abort(request_id) + + @property + def is_running(self) -> bool: + """Check if the engine is running.""" + return not self._closed + + @property + def is_stopped(self) -> bool: + """Check if the engine is stopped.""" + return self._closed + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "remove_lora", + None, + (adapter_id,), + {}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Add a LoRA adapter""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "add_lora", + None, + (), + {"lora_request": lora_request, "lora_scale": lora_scale}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "list_loras", + None, + (), + {}, + None, + ) + # collective_rpc returns list from workers; flatten unique ids + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "pin_lora", + None, + (), + {"adapter_id": lora_id}, + None, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..b557c07dd57042371f9dec6e464b37955572f6fb --- /dev/null +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import os +import socket +from typing import TYPE_CHECKING + +import torch +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tracing import init_tracer +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.usage.usage_lib import UsageContext +from vllm.utils.func_utils import deprecate_kwargs +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.executor.abstract import Executor +from 
vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager + +from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs +from vllm_omni.engine.input_processor import OmniInputProcessor +from vllm_omni.engine.output_processor import MultimodalOutputProcessor + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class AsyncOmniLLM(AsyncLLM): + """Async single-stage LLM engine for use within a stage worker process. + + This class extends the base vLLM AsyncLLM class with omni-specific + processors for handling multimodal inputs and outputs. It is used + internally by AsyncOmniStage workers and should not be instantiated + directly by users. + + Args: + engine_args: AsyncOmniEngineArgs containing engine configuration + vllm_config: Global vLLM configuration + executor_class: Executor implementation class, e.g. MultiprocExecutor + log_stats: Whether to log statistics + usage_context: Usage context of the LLM (default: ENGINE_CONTEXT) + mm_registry: Multi-modal registry for processing multimodal inputs + use_cached_outputs: Whether to use cached outputs + log_requests: Whether to log requests + start_engine_loop: Whether to start the engine loop automatically + stat_loggers: Customized stat loggers for the engine. + If not provided, default stat loggers will be used. + Note: Stat logger interface may change in V1. + client_addresses: Optional dictionary mapping client names to addresses + client_count: Total number of clients (default: 1) + client_index: Index of this client (default: 0) + """ + + def __init__( + self, + engine_args: AsyncOmniEngineArgs, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + log_requests: bool = True, + start_engine_loop: bool = True, + stat_loggers: list[StatLoggerFactory] | None = None, + client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + ) -> None: + """ + Create an AsyncOmniLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ + # Ensure we can serialize custom transformer configs + maybe_register_config_serialize_by_value() + + self.model_config = vllm_config.model_config + self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config + self.log_requests = log_requests + + self.log_stats = log_stats or (stat_loggers is not None) + if not log_stats and stat_loggers is not None: + logger.info( + "AsyncLLM created with log_stats=False and non-empty custom logger list; " + "enabling logging without default stat loggers" + ) + + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + tokenizer = cached_tokenizer_from_config(model_config=vllm_config.model_config) + + # InputProcessor (converts Inputs --> EngineCoreRequests). 
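+ # (Together with the MultimodalOutputProcessor below, the per-stage data path
+ # is roughly: prompt -> OmniInputProcessor -> EngineCoreRequest -> engine core
+ # -> EngineCoreOutput -> MultimodalOutputProcessor -> RequestOutput with any
+ # attached multimodal tensors.)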
+ self.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + mm_registry=mm_registry, + ) + + # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). + self.output_processor = MultimodalOutputProcessor( + tokenizer=tokenizer, + log_stats=self.log_stats, + engine_core_output_type=engine_args.engine_output_type, + ) + + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer + + # Pause / resume state for async RL workflows. + self._pause_cond = asyncio.Condition() + self._paused = False + + # EngineCore (starts the engine in background process). + self.engine_core = EngineCoreClient.make_async_mp_client( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=self.log_stats, + client_addresses=client_addresses, + client_count=client_count, + client_index=client_index, + ) + + # Loggers. + self.logger_manager: StatLoggerManager | None = None + if self.log_stats: + self.logger_manager = StatLoggerManager( + vllm_config=vllm_config, + engine_idxs=self.engine_core.engine_ranks_managed, + custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + client_count=client_count, + ) + self.logger_manager.log_engine_initialized() + + self.output_handler: asyncio.Task | None = None + try: + # Start output handler eagerly if we are in the asyncio eventloop. + asyncio.get_running_loop() + self._run_output_handler() + except RuntimeError: + pass + + if envs.VLLM_TORCH_PROFILER_DIR and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: + logger.info( + "Torch profiler enabled. AsyncOmniLLM CPU traces will be collected under %s", + envs.VLLM_TORCH_PROFILER_DIR, + ) + worker_name = f"{socket.gethostname()}_{os.getpid()}.async_omni_llm" + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + envs.VLLM_TORCH_PROFILER_DIR, + worker_name=worker_name, + use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + ), + ) + else: + self.profiler = None + + @classmethod + @deprecate_kwargs( + "disable_log_requests", + additional_message=("This argument will have no effect. Use `enable_log_requests` instead."), + ) + def from_vllm_config( + cls, + vllm_config: VllmConfig, + engine_args: AsyncOmniEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: list[StatLoggerFactory] | None = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed + ) -> "AsyncLLM": + # Create the LLMEngine. 
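+ # Illustrative call site (engine-arg construction is an assumption, not part
+ # of this change):
+ #
+ #     engine_args = AsyncOmniEngineArgs(model="Qwen/Qwen2.5-Omni-7B",
+ #                                       engine_output_type="audio")
+ #     vllm_config = engine_args.create_engine_config(usage_context)
+ #     llm = AsyncOmniLLM.from_vllm_config(vllm_config, engine_args=engine_args)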
+ return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, + log_requests=enable_log_requests, + log_stats=not disable_log_stats, + usage_context=usage_context, + client_addresses=client_addresses, + client_count=client_count, + client_index=client_index, + engine_args=engine_args, + ) diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca15305c51f23cfdcfec400d7a7885b02d1a1043 --- /dev/null +++ b/vllm_omni/entrypoints/chat_utils.py @@ -0,0 +1,259 @@ +from collections.abc import Awaitable, Iterable +from typing import Any, cast + +import numpy as np +from openai.types.chat import ChatCompletionContentPartTextParam +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + AsyncMultiModalContentParser, + AsyncMultiModalItemTracker, + BaseMultiModalContentParser, + BaseMultiModalItemTracker, + ChatCompletionContentPartParam, + ChatCompletionMessageParam, + ChatTemplateContentFormat, + ConversationMessage, + MultiModalDataDict, + MultiModalUUIDDict, + _AssistantParser, + _ContentPart, + _get_full_multimodal_text_prompt, + _parse_chat_message_content_part, + _postprocess_messages, + _ToolParser, +) + + +class OmniAsyncMultiModalItemTracker(AsyncMultiModalItemTracker): + def create_parser(self) -> "BaseMultiModalContentParser": + return OmniAsyncMultiModalContentParser(self) + + +class OmniAsyncMultiModalContentParser(AsyncMultiModalContentParser): + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__(tracker=tracker) + self._mm_processor_kwargs: dict[str, Any] | None = None + + def set_mm_processor_kwargs(self, mm_processor_kwargs: dict[str, Any] | None) -> None: + """Set mm_processor_kwargs for use in parsing.""" + self._mm_processor_kwargs = mm_processor_kwargs + + def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: + # OMNI: Follow upstream async pattern - create coroutine that resolves to (data, uuid) + coro = self._video_with_uuid_async(video_url, uuid) + placeholder = self._tracker.add("video", coro) + self._add_placeholder("video", placeholder) + + # Extract audio from video if use_audio_in_video is True + if video_url and self._mm_processor_kwargs and self._mm_processor_kwargs.get("use_audio_in_video", False): + audio_coro = self._audio_from_video_with_uuid_async(video_url, uuid) + audio_placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder("audio", audio_placeholder) + + async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None): + """Fetch video and return (video, uuid) tuple.""" + video = await self._connector.fetch_video_async(video_url=video_url) if video_url else None + return video, uuid + + async def _audio_from_video_with_uuid_async(self, video_url: str, uuid: str | None): + """Extract audio from video and return (audio, uuid) tuple.""" + audio = await self._extract_audio_from_video_async(video_url) + return audio, uuid + + async def _extract_audio_from_video_async(self, video_url: str) -> tuple[np.ndarray, int | float]: + """ + Extract audio from video URL using librosa. + Returns tuple of (audio_array, sample_rate) compatible with audio format. + + All blocking I/O operations are run in a thread pool to avoid blocking the event loop. 
+ """ + import asyncio + import os + import tempfile + from urllib.parse import urlparse + + # Parse URL to determine type + parsed_url = urlparse(video_url) + temp_video_file_path = None + + def _download_video_sync(url: str) -> bytes: + """Synchronous video download - runs in thread pool.""" + from urllib.request import urlopen + + return urlopen(url).read() + + def _write_temp_file_sync(data: bytes, suffix: str) -> str: + """Synchronous temp file write - runs in thread pool.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(data) + return temp_file.name + + def _load_audio_sync(file_path: str) -> tuple[np.ndarray, int | float]: + """Synchronous audio loading with librosa - runs in thread pool.""" + import librosa + + return librosa.load(file_path, sr=16000) + + def _cleanup_file_sync(file_path: str) -> None: + """Synchronous file deletion - runs in thread pool.""" + try: + if os.path.exists(file_path): + os.unlink(file_path) + except OSError: + pass + + try: + if parsed_url.scheme in ("http", "https"): + # Download video from HTTP/HTTPS URL asynchronously + video_data = await asyncio.to_thread(_download_video_sync, video_url) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + elif parsed_url.scheme == "file": + # Use file path directly (handle Windows paths) + from urllib.request import url2pathname + + temp_video_file_path = url2pathname(parsed_url.path) + elif parsed_url.scheme == "data": + # Handle data URL (base64 encoded video) + import base64 + + header, data = video_url.split(",", 1) + video_data = base64.b64decode(data) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + else: + # Assume it's a local file path + temp_video_file_path = video_url + + # Extract audio using librosa asynchronously (CPU-intensive, runs in thread pool) + audio_array, sample_rate = await asyncio.to_thread(_load_audio_sync, temp_video_file_path) + + return audio_array, sample_rate + finally: + # Clean up temporary file if we created one (asynchronously) + if temp_video_file_path and parsed_url.scheme in ("http", "https", "data"): + await asyncio.to_thread(_cleanup_file_sync, temp_video_file_path) + + +def parse_chat_messages_futures( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + content_format: ChatTemplateContentFormat, + mm_processor_kwargs: dict[str, Any] | None = None, +) -> tuple[ + list[ConversationMessage], + Awaitable[tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]], +]: + """Parse chat messages and return conversation with multimodal data future. + + OMNI: Updated to use upstream vLLM v0.15.0 API where resolve_items() + returns both mm_data and mm_uuids together as a tuple. + + Returns: + Tuple of (conversation, mm_future) where mm_future resolves to + (mm_data, mm_uuids) when awaited. 
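+
+    A minimal usage sketch (call-site variable names are illustrative):
+
+        conversation, mm_future = parse_chat_messages_futures(
+            messages, model_config, content_format,
+            mm_processor_kwargs={"use_audio_in_video": True},
+        )
+        mm_data, mm_uuids = await mm_future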
+ """ + conversation: list[ConversationMessage] = [] + mm_tracker = OmniAsyncMultiModalItemTracker(model_config) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + interleave_strings=( + content_format == "string" + and model_config.multimodal_config is not None + and model_config.multimodal_config.interleave_mm_strings + ), + mm_processor_kwargs=mm_processor_kwargs, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + # OMNI: Use upstream resolve_items() which returns (mm_data, mm_uuids) tuple + return conversation, mm_tracker.resolve_items() + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, + content_format: ChatTemplateContentFormat, + interleave_strings: bool, + mm_processor_kwargs: dict[str, Any] | None = None, +) -> list[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ChatCompletionContentPartTextParam(type="text", text=content)] + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + wrap_dicts=(content_format == "openai"), + interleave_strings=interleave_strings, + mm_processor_kwargs=mm_processor_kwargs, + ) + + for result_msg in result: + if role == "assistant": + parsed_msg = _AssistantParser(message) + + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, + *, + wrap_dicts: bool, + interleave_strings: bool, + mm_processor_kwargs: dict[str, Any] | None = None, +) -> list[ConversationMessage]: + content = list[_ContentPart]() + + mm_parser = mm_tracker.create_parser() + # Set mm_processor_kwargs if parser supports it + if hasattr(mm_parser, "set_mm_processor_kwargs"): + mm_parser.set_mm_processor_kwargs(mm_processor_kwargs) + + for part in parts: + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + interleave_strings=interleave_strings, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, content=content)] # type: ignore + texts = cast(list[str], content) + mm_placeholder_storage = mm_parser.mm_placeholder_storage() + if mm_placeholder_storage: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, texts, interleave_strings) + else: + text_prompt = "\n".join(texts) + + return [ConversationMessage(role=role, content=text_prompt)] diff --git a/vllm_omni/entrypoints/cli/__init__.py b/vllm_omni/entrypoints/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffba613055802ccc06f8ca1b180de0d9044f947 --- /dev/null +++ b/vllm_omni/entrypoints/cli/__init__.py @@ -0,0 +1,13 @@ +"""CLI helpers for vLLM-Omni entrypoints.""" + +# To 
ensure patch imports work properly, disable unused import checks +# ruff: noqa: E402, F401 +# isort: off +from vllm_omni.benchmarks.patch import patch +# isort: on + +from vllm_omni.entrypoints.cli.benchmark.serve import OmniBenchmarkServingSubcommand + +from .serve import OmniServeCommand + +__all__ = ["OmniServeCommand", "OmniBenchmarkServingSubcommand"] diff --git a/vllm_omni/entrypoints/cli/benchmark/__init__.py b/vllm_omni/entrypoints/cli/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/entrypoints/cli/benchmark/base.py b/vllm_omni/entrypoints/cli/benchmark/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6f97eb1e8e9dde6ae3cce8b8ed96b4eb2fd679 --- /dev/null +++ b/vllm_omni/entrypoints/cli/benchmark/base.py @@ -0,0 +1,23 @@ +import argparse + +from vllm.entrypoints.cli.types import CLISubcommand + + +class OmniBenchmarkSubcommandBase(CLISubcommand): + """The base class of subcommands for vllm bench.""" + + help: str + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + """Add the CLI arguments to the parser.""" + raise NotImplementedError + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the benchmark. + + Args: + args: The arguments to the command. + """ + raise NotImplementedError diff --git a/vllm_omni/entrypoints/cli/benchmark/main.py b/vllm_omni/entrypoints/cli/benchmark/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8880e35c7cf1bf6204db295c5a1bec262e8e2ba2 --- /dev/null +++ b/vllm_omni/entrypoints/cli/benchmark/main.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import argparse +import typing + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG + +from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +class OmniBenchmarkSubcommand(CLISubcommand): + """The `bench` subcommand for the vLLM CLI.""" + + name = "bench" + help = "vLLM-omni bench subcommand." 
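+    # Illustrative invocation (the serve benchmark's own flags come from
+    # vllm.benchmarks.serve.add_cli_args, so exact options may differ):
+    #     vllm bench serve --omni --model Qwen/Qwen2.5-Omni-7B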
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        args.dispatch_function(args)
+
+    def validate(self, args: argparse.Namespace) -> None:
+        pass
+
+    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        bench_parser = subparsers.add_parser(
+            self.name, description=self.help, usage=f"vllm {self.name} <bench_type> [options]"
+        )
+        bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
+
+        for cmd_cls in OmniBenchmarkSubcommandBase.__subclasses__():
+            cmd_subparser = bench_subparsers.add_parser(
+                cmd_cls.name,
+                help=cmd_cls.help,
+                description=cmd_cls.help,
+                usage=f"vllm {self.name} {cmd_cls.name} [--omni] [options]",
+            )
+            cmd_subparser.add_argument(
+                "--omni",
+                action="store_true",
+                help="Enable benchmark-Omni mode (always enabled for omni commands)",
+            )
+            cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd)
+            cmd_cls.add_cli_args(cmd_subparser)
+
+            cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=f"{self.name} {cmd_cls.name}")
+
+        return bench_parser
+
+
+def cmd_init() -> list[CLISubcommand]:
+    return [OmniBenchmarkSubcommand()]
diff --git a/vllm_omni/entrypoints/cli/benchmark/serve.py b/vllm_omni/entrypoints/cli/benchmark/serve.py
new file mode 100644
index 0000000000000000000000000000000000000000..906e8851a4a4ff9c555fd179419b85717ce2bf81
--- /dev/null
+++ b/vllm_omni/entrypoints/cli/benchmark/serve.py
@@ -0,0 +1,51 @@
+import argparse
+
+from vllm.benchmarks.serve import add_cli_args
+
+from vllm_omni.benchmarks.serve import main
+from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase
+
+
+class OmniBenchmarkServingSubcommand(OmniBenchmarkSubcommandBase):
+    """The `serve` subcommand for vllm bench."""
+
+    name = "serve"
+    help = "Benchmark the online serving throughput."
+
+    @classmethod
+    def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
+        add_cli_args(parser)
+        for action in parser._actions:
+            if action.dest == "percentile_metrics":
+                action.help = (
+                    "Comma-separated list of metrics to report percentiles for. "
+                    'Allowed metric names are "ttft", "tpot", "itl", "e2el", '
+                    '"audio_ttfp", "audio_rtf".'
+                )
+            if action.dest == "random_mm_limit_mm_per_prompt":
+                action.help = (
+                    "Per-modality hard caps for items attached per request, e.g. "
+                    '\'{"image": 3, "video": 0, "audio": 1}\'. The sampled per-request item '
+                    "count is clamped to the sum of these limits. When a modality "
+                    "reaches its cap, its buckets are excluded and probabilities are "
+                    "renormalized."
+                )
+            if action.dest == "random_mm_bucket_config":
+                action.help = (
+                    "The bucket config is a dictionary mapping a multimodal item "
+                    "sampling configuration to a probability. "
+                    "Currently allows for 3 modalities: audio, images and videos. "
+                    "A bucket key is a tuple of (height, width, num_frames). "
+                    "The value is the probability of sampling that specific item. "
+                    "Example: "
+                    "--random-mm-bucket-config "
+                    "{(256, 256, 1): 0.5, (720, 1280, 16): 0.4, (0, 1, 5): 0.10} "
+                    "First item: images with resolution 256x256 w.p. 0.5. "
+                    "Second item: videos with resolution 720x1280 and 16 frames w.p. 0.4. "
+                    "Third item: audio clips with 1s duration and 5 channels w.p. 0.1. "
+                    "Note: if the probabilities do not sum to 1, they are normalized."
+ ) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..629a4641cce6bf51984d20ef5c8939e0ff4ed9b0 --- /dev/null +++ b/vllm_omni/entrypoints/cli/main.py @@ -0,0 +1,59 @@ +""" +CLI entry point for vLLM-Omni that intercepts vLLM commands. +""" + +import importlib.metadata +import sys + + +def main(): + """Main CLI entry point that intercepts vLLM commands.""" + # Check if --omni flag is present + if "--omni" not in sys.argv: + from vllm.entrypoints.cli.main import main as vllm_main + + vllm_main() + return + else: + from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup + from vllm.utils.argparse_utils import FlexibleArgumentParser + + import vllm_omni.entrypoints.cli.benchmark.main + import vllm_omni.entrypoints.cli.serve + + CMD_MODULES = [ + vllm_omni.entrypoints.cli.serve, + vllm_omni.entrypoints.cli.benchmark.main, + ] + + cli_env_setup() + + parser = FlexibleArgumentParser( + description="vLLM OMNI CLI", + epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=importlib.metadata.version("vllm_omni"), + ) + subparsers = parser.add_subparsers(required=False, dest="subparser") + cmds = {} + for cmd_module in CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..acefc081b0bb705b9cc77ea22bb7c5027966891f --- /dev/null +++ b/vllm_omni/entrypoints/cli/serve.py @@ -0,0 +1,255 @@ +""" +Omni serve command for vLLM-Omni. + +Supports both multi-stage LLM models (e.g., Qwen2.5-Omni) and +diffusion models (e.g., Qwen-Image) through the same CLI interface. +""" + +import argparse + +import uvloop +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni.entrypoints.openai.api_server import omni_run_server + +logger = init_logger(__name__) + +DESCRIPTION = """Launch a local OpenAI-compatible API server to serve Omni models +via HTTP. Supports both multi-stage LLM models and diffusion models. + +The server automatically detects the model type: +- LLM models: Served via /v1/chat/completions endpoint +- Diffusion models: Served via /v1/images/generations endpoint + +Examples: + # Start an Omni LLM server + vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 + + # Start a diffusion model server + vllm serve Qwen/Qwen-Image --omni --port 8091 + +Search by using: `--help=<ConfigGroup>` to explore options by section (e.g., +--help=OmniConfig) + Use `--help=all` to show all available flags at once. 
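+
+Example request against a running LLM server (standard OpenAI-compatible payload,
+using the same port as the examples above):
+
+    curl http://localhost:8091/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "Qwen/Qwen2.5-Omni-7B", "messages": [{"role": "user", "content": "Hello"}]}'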
+""" + + +class OmniServeCommand(CLISubcommand): + """The `serve` subcommand for the vLLM CLI.""" + + name = "serve" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + # If model is specified in CLI (as positional arg), it takes precedence + if hasattr(args, "model_tag") and args.model_tag is not None: + args.model = args.model_tag + + uvloop.run(omni_run_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + # Skip validation for diffusion models as they have different requirements + from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model + + model = getattr(args, "model_tag", None) or getattr(args, "model", None) + if model and is_diffusion_model(model): + logger.info("Detected diffusion model: %s", model) + return + validate_parsed_serve_args(args) + + def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + self.name, + description=DESCRIPTION, + usage="vllm serve [model_tag] --omni [options]", + ) + + serve_parser = make_arg_parser(serve_parser) + serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) + + # Create OmniConfig argument group for omni-related parameters + # This ensures the parameters appear in --help output + omni_config_group = serve_parser.add_argument_group( + title="OmniConfig", description="Configuration for vLLM-Omni multi-stage and diffusion models." + ) + + omni_config_group.add_argument( + "--omni", + action="store_true", + help="Enable vLLM-Omni mode for multi-modal and diffusion models", + ) + omni_config_group.add_argument( + "--stage-configs-path", + type=str, + default=None, + help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.", + ) + omni_config_group.add_argument( + "--stage-init-timeout", + type=int, + default=300, + help="The timeout for initializing a single stage in seconds (default: 300)", + ) + omni_config_group.add_argument( + "--init-timeout", + type=int, + default=600, + help="The timeout for initializing the stages.", + ) + omni_config_group.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="The threshold for the shared memory size.", + ) + omni_config_group.add_argument( + "--log-stats", + action="store_true", + help="Enable logging the stats.", + ) + omni_config_group.add_argument( + "--log-file", + type=str, + default=None, + help="The path to the log file.", + ) + omni_config_group.add_argument( + "--batch-timeout", + type=int, + default=10, + help="The timeout for the batch.", + ) + omni_config_group.add_argument( + "--worker-backend", + type=str, + default="multi_process", + choices=["multi_process", "ray"], + help="The backend to use for stage workers.", + ) + omni_config_group.add_argument( + "--ray-address", + type=str, + default=None, + help="The address of the Ray cluster to connect to.", + ) + + # Diffusion model specific arguments + omni_config_group.add_argument( + "--num-gpus", + type=int, + default=None, + help="Number of GPUs to use for diffusion model inference.", + ) + omni_config_group.add_argument( + "--usp", + "--ulysses-degree", + dest="ulysses_degree", + type=int, + default=None, + help="Ulysses Sequence Parallelism degree for diffusion models. " + "Equivalent to setting DiffusionParallelConfig.ulysses_degree.", + ) + omni_config_group.add_argument( + "--ring", + dest="ring_degree", + type=int, + default=None, + help="Ring Sequence Parallelism degree for diffusion models. 
" + "Equivalent to setting DiffusionParallelConfig.ring_degree.", + ) + + # Cache optimization parameters + omni_config_group.add_argument( + "--cache-backend", + type=str, + default="none", + help="Cache backend for diffusion models, options: 'tea_cache', 'cache_dit'", + ) + omni_config_group.add_argument( + "--cache-config", + type=str, + default=None, + help="JSON string of cache configuration (e.g., '{\"rel_l1_thresh\": 0.2}').", + ) + omni_config_group.add_argument( + "--enable-cache-dit-summary", + action="store_true", + help="Enable cache-dit summary logging after diffusion forward passes.", + ) + + # VAE memory optimization parameters + omni_config_group.add_argument( + "--vae-use-slicing", + action="store_true", + help="Enable VAE slicing for memory optimization (useful for mitigating OOM issues).", + ) + omni_config_group.add_argument( + "--vae-use-tiling", + action="store_true", + help="Enable VAE tiling for memory optimization (useful for mitigating OOM issues).", + ) + + # diffusion model offload parameters + serve_parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offloading for diffusion models.", + ) + serve_parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) + serve_parser.add_argument( + "--layerwise-num-gpu-layers", + type=int, + default=1, + help="Number of layers (blocks) to keep on GPU during generation.", + ) + + # Video model parameters (e.g., Wan2.2) - engine-level + omni_config_group.add_argument( + "--boundary-ratio", + type=float, + default=None, + help="Boundary split ratio for low/high DiT in video models (e.g., 0.875 for Wan2.2).", + ) + omni_config_group.add_argument( + "--flow-shift", + type=float, + default=None, + help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).", + ) + omni_config_group.add_argument( + "--cfg-parallel-size", + type=int, + default=1, + choices=[1, 2], + help="Number of devices for CFG parallel computation for diffusion models. " + "Equivalent to setting DiffusionParallelConfig.cfg_parallel_size.", + ) + + # Default sampling parameters + omni_config_group.add_argument( + "--default-sampling-params", + type=str, + help="Json str for Default sampling parameters, \n" + 'Structure: {"<stage_id>": {<sampling_param>: value, ...}, ...}\n' + 'e.g., \'{"0": {"num_inference_steps":50, "guidance_scale":1}}\'. 
' + "Currently only supports diffusion models.", + ) + # Diffusion model mixed precision + omni_config_group.add_argument( + "--max-generated-image-size", + type=float, + help="The max size of generate image (height * width).", + ) + return serve_parser + + +def cmd_init() -> list[CLISubcommand]: + return [OmniServeCommand()] diff --git a/vllm_omni/entrypoints/client_request_state.py b/vllm_omni/entrypoints/client_request_state.py new file mode 100644 index 0000000000000000000000000000000000000000..3e68abb173ef96608cc6440299da88cd7cce6041 --- /dev/null +++ b/vllm_omni/entrypoints/client_request_state.py @@ -0,0 +1,13 @@ +import asyncio + +from vllm_omni.entrypoints.log_utils import OrchestratorMetrics + + +class ClientRequestState: + """Tracks the state of an individual request in the orchestrator.""" + + def __init__(self, request_id: str, queue: asyncio.Queue | None = None): + self.request_id = request_id + self.stage_id: int | None = None + self.queue = queue if queue is not None else asyncio.Queue() + self.metrics: OrchestratorMetrics | None = None diff --git a/vllm_omni/entrypoints/log_utils.py b/vllm_omni/entrypoints/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2031038da38bd98b7530b299e1e92b5d25e7f0d0 --- /dev/null +++ b/vllm_omni/entrypoints/log_utils.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pprint import pformat +from typing import Any + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def log_transfer_tx( + from_stage: int, + to_stage: int, + request_id: Any, + size_bytes: int, + tx_time_ms: float, + used_shm: bool, +) -> None: + logger.info( + pformat( + { + "type": "transfer_stats", + "from_stage": from_stage, + "to_stage": to_stage, + "request_id": request_id, + "size_bytes": int(size_bytes), + "tx_time_ms": float(tx_time_ms), + "tx_mbps": (float(size_bytes) * 8.0) / (max(tx_time_ms, 1e-6) * 1000.0), + "used_shm": bool(used_shm), + }, + sort_dicts=False, + ) + ) + + +def log_transfer_rx( + from_stage: int, + to_stage: int, + request_id: Any, + rx_bytes: int, + rx_decode_time_ms: float, + in_flight_time_ms: float, +) -> None: + logger.info( + pformat( + { + "type": "transfer_rx_stats", + "from_stage": from_stage, + "to_stage": to_stage, + "request_id": request_id, + "rx_bytes": int(rx_bytes), + "rx_decode_time_ms": float(rx_decode_time_ms), + "in_flight_time_ms": float(in_flight_time_ms), + "rx_time_per_kb_ms": ( + (float(rx_decode_time_ms) / max(float(rx_bytes) / 1024.0, 1e-6)) if rx_bytes > 0 else 0.0 + ), + }, + sort_dicts=False, + ) + ) + + +def log_transfer_total( + from_stage: int, + to_stage: int, + request_id: Any, + size_bytes: int, + tx_time_ms: float, + in_flight_time_ms: float, + rx_decode_time_ms: float, + total_time_ms: float, +) -> None: + logger.info( + pformat( + { + "type": "transfer_total_stats", + "from_stage": from_stage, + "to_stage": to_stage, + "request_id": request_id, + "size_bytes": int(size_bytes), + "tx_time_ms": float(tx_time_ms), + "in_flight_time_ms": float(in_flight_time_ms), + "rx_decode_time_ms": float(rx_decode_time_ms), + "total_time_ms": float(total_time_ms), + "total_time_per_kb_ms": ( + float(total_time_ms) / max(float(size_bytes) / 1024.0, 1e-6) if size_bytes > 0 else 0.0 + ), + }, + sort_dicts=False, + ) + ) + + +def log_stage_request_stats( + stage_id: int, + request_id: Any, + batch_size: int, + num_tokens_out: int, + stage_gen_time_ms: float, + tokens_per_s: float, + rx_transfer_bytes: int, + 
rx_decode_time_ms: float, + rx_mbps: float, +) -> None: + logger.info( + pformat( + { + "type": "Request_stage_stats", + "stage_id": stage_id, + "request_id": request_id, + "batch_size": int(batch_size), + "num_tokens_out": int(num_tokens_out), + "stage_gen_time_ms": float(stage_gen_time_ms), + "tokens_per_s": float(tokens_per_s), + "rx_transfer_bytes": int(rx_transfer_bytes), + "rx_decode_time_ms": float(rx_decode_time_ms), + "rx_mbps": float(rx_mbps), + }, + sort_dicts=False, + ) + ) + + +def compute_and_log_stage_request_stats( + stage_id: int, + request_id: Any, + batch_size: int, + num_engine_outputs: int, + stage_gen_time_ms: float, + rx_transfer_bytes: int, + rx_decode_time_ms: float, +) -> None: + """Compute per-request metrics and log them in one call.""" + tokens_per_s = (num_engine_outputs * 1000.0 / stage_gen_time_ms) if stage_gen_time_ms > 0 else 0.0 + rx_mbps = ( + (float(rx_transfer_bytes) * 8.0) / (max(float(rx_decode_time_ms), 1e-6) * 1000.0) + if rx_transfer_bytes > 0 + else 0.0 + ) + log_stage_request_stats( + stage_id, + request_id, + int(batch_size), + int(num_engine_outputs), + float(stage_gen_time_ms), + float(tokens_per_s), + int(rx_transfer_bytes), + float(rx_decode_time_ms), + float(rx_mbps), + ) + + +# ----------------- Aggregation helpers for orchestrator ----------------- + + +def record_stage_metrics( + per_request: dict[str, dict[str, Any]], + stage_req_counts: list[int], + stage_total_time_ms: list[float], + stage_total_tokens: list[int], + stage_id: int, + req_id: Any, + metrics: dict[str, Any], +) -> None: + try: + stage_req_counts[stage_id] += 1 + stage_total_tokens[stage_id] += int(metrics.get("num_tokens_out", 0)) + rid_key = str(req_id) + pr = per_request.setdefault(rid_key, {"stages": {}, "transfers_ms": 0.0, "transfers_bytes": 0}) + pr_stages = pr["stages"] # type: ignore[index] + stage_data: dict[str, Any] = { + "stage_gen_time_ms": float(metrics.get("stage_gen_time_ms", 0.0)), + "num_tokens_out": int(metrics.get("num_tokens_out", 0)), + } + # Only record num_tokens_in for stage 0 (initial prompt) + if stage_id == 0: + stage_data["num_tokens_in"] = int(metrics.get("num_tokens_in", 0)) + stage_total_tokens[stage_id] += int(metrics.get("num_tokens_in", 0)) + pr_stages[stage_id] = stage_data + except Exception: + pass + + +def aggregate_rx_and_maybe_total( + transfer_edge_req: dict[tuple[int, int, str], dict[str, float]], + transfer_agg: dict[tuple[int, int], dict[str, float]], + per_request: dict[str, dict[str, Any]], + stage_id: int, + req_id: Any, + rx_bytes: float, + rx_ms: float, + in_flight_ms: float, +) -> tuple[int, float, float] | None: + try: + # Update RX aggregates for (stage_id-1 -> stage_id) + if stage_id > 0: + key = (stage_id - 1, stage_id) + agg = transfer_agg.get(key) + if agg is None: + agg = { + "sum_bytes": 0.0, + "sum_ms": 0.0, + "count": 0.0, + "sum_rx_bytes": 0.0, + "sum_rx_ms": 0.0, + "rx_count": 0.0, + "sum_total_ms": 0.0, + "total_count": 0.0, + } + transfer_agg[key] = agg + agg["sum_rx_bytes"] += float(rx_bytes) + agg["sum_rx_ms"] += float(rx_ms) + agg["rx_count"] += 1.0 + + # Try combine with sender-side timing if present + rid_key = str(req_id) + s = transfer_edge_req.get((stage_id - 1, stage_id, rid_key)) + if s is None: + return None + tx_ms = float(s.get("tx_ms", 0.0)) + size_b = float(s.get("size_bytes", rx_bytes)) + total_ms = tx_ms + float(in_flight_ms) + float(rx_ms) + agg["sum_total_ms"] += total_ms + agg["total_count"] += 1.0 + # accumulate per-request transfer totals + try: + pr = per_request.setdefault(rid_key, 
{"stages": {}, "transfers_ms": 0.0, "transfers_bytes": 0}) + pr["transfers_ms"] = float(pr.get("transfers_ms", 0.0)) + total_ms # type: ignore[index] + pr["transfers_bytes"] = int(pr.get("transfers_bytes", 0)) + int(rx_bytes) # type: ignore[index] + except Exception: + pass + return int(size_b), float(tx_ms), float(total_ms) + return None + except Exception: + return None + + +def record_sender_transfer_agg( + transfer_agg: dict[tuple[int, int], dict[str, float]], + transfer_edge_req: dict[tuple[int, int, str], dict[str, float]], + from_stage: int, + to_stage: int, + req_id: Any, + size_bytes: int, + tx_ms: float, +) -> None: + try: + key = (from_stage, to_stage) + agg = transfer_agg.get(key) + if agg is None: + agg = { + "sum_bytes": 0.0, + "sum_ms": 0.0, + "count": 0.0, + "sum_rx_bytes": 0.0, + "sum_rx_ms": 0.0, + "rx_count": 0.0, + "sum_total_ms": 0.0, + "total_count": 0.0, + } + transfer_agg[key] = agg + agg["sum_bytes"] += float(size_bytes) + agg["sum_ms"] += float(tx_ms) + agg["count"] += 1.0 + # Store sender-side timing for per-request combination + rid_key = str(req_id) + transfer_edge_req[(from_stage, to_stage, rid_key)] = { + "tx_ms": float(tx_ms), + "size_bytes": float(size_bytes), + } + except Exception: + pass + + +def count_tokens_from_outputs(engine_outputs: list[Any]) -> int: + total = 0 + for _ro in engine_outputs: + try: + outs = getattr(_ro, "outputs", None) + if outs and len(outs) > 0: + tokens = getattr(outs[0], "token_ids", None) + if tokens is not None: + total += len(tokens) + except Exception: + pass + return total + + +def build_stage_summary( + stage_req_counts: list[int], + stage_total_tokens: list[int], + stage_total_time_ms: list[float], +) -> list[dict[str, Any]]: + summary: list[dict[str, Any]] = [] + for sid in range(len(stage_req_counts)): + reqs = stage_req_counts[sid] + tokens = stage_total_tokens[sid] + total_ms = float(stage_total_time_ms[sid]) + avg_req = (total_ms / reqs) if reqs > 0 else 0.0 + avg_tok = (tokens * 1000.0 / total_ms) if total_ms > 0 else 0.0 + summary.append( + { + "stage_id": sid, + "requests": int(reqs), + "tokens": int(tokens), + "total_time_ms": total_ms, + "avg_time_per_request_ms": avg_req, + "avg_tokens_per_s": avg_tok, + } + ) + return summary + + +def build_transfer_summary( + transfer_agg: dict[tuple[int, int], dict[str, float]], +) -> list[dict[str, Any]]: + summary: list[dict[str, Any]] = [] + for (src, dst), agg in transfer_agg.items(): + sum_bytes = float(agg.get("sum_bytes", 0.0)) + sum_ms = float(agg.get("sum_ms", 0.0)) + samples = int(agg.get("count", 0.0)) + tx_mbps = (sum_bytes * 8.0) / (max(sum_ms, 1e-6) * 1000.0) if sum_bytes > 0 else 0.0 + sum_rx_bytes = float(agg.get("sum_rx_bytes", 0.0)) + sum_rx_ms = float(agg.get("sum_rx_ms", 0.0)) + samples_rx = int(agg.get("rx_count", 0.0)) + rx_mbps = (sum_rx_bytes * 8.0) / (max(sum_rx_ms, 1e-6) * 1000.0) if sum_rx_bytes > 0 else 0.0 + sum_total_ms = float(agg.get("sum_total_ms", 0.0)) + samples_total = int(agg.get("total_count", 0.0)) + total_mbps = (sum_bytes * 8.0) / (max(sum_total_ms, 1e-6) * 1000.0) if sum_bytes > 0 else 0.0 + summary.append( + { + "from_stage": src, + "to_stage": dst, + "samples": samples, + "total_bytes": int(sum_bytes), + "total_time_ms": sum_ms, + "tx_mbps": tx_mbps, + "rx_samples": samples_rx, + "rx_total_bytes": int(sum_rx_bytes), + "rx_total_time_ms": sum_rx_ms, + "rx_mbps": rx_mbps, + "total_samples": samples_total, + "total_transfer_time_ms": sum_total_ms, + "total_mbps": total_mbps, + } + ) + return summary + + +@dataclass +class StageStats: 
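+    # Aggregate counters carried in a stage's metrics payload; elsewhere in this
+    # module the running throughput is derived as total_token * 1000 / total_gen_time
+    # (total_gen_time is in milliseconds).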
+ total_token: int + total_gen_time: float + + +@dataclass +class StageRequestMetrics: + num_tokens_in: int + num_tokens_out: int + stage_gen_time_ms: float + batch_id: int + batch_size: int + rx_decode_time_ms: float + rx_transfer_bytes: int + rx_in_flight_time_ms: float + + stage_stats: StageStats + + +class OrchestratorMetrics: + def __init__( + self, + num_stages: int, + enable_stats: bool, + wall_start_ts: float, + ) -> None: + self.num_stages = int(num_stages) + self.enable_stats = bool(enable_stats) + self.stage_total_time_ms: list[float] = [0.0 for _ in range(self.num_stages)] + self.stage_total_tokens: list[int] = [0 for _ in range(self.num_stages)] + self.stage_req_counts: list[int] = [0 for _ in range(self.num_stages)] + self.transfer_agg: dict[tuple[int, int], dict[str, float]] = {} + self.transfer_edge_req: dict[tuple[int, int, str], dict[str, float]] = {} + self.e2e_total_ms: float = 0.0 + self.e2e_total_tokens: int = 0 + self.e2e_count: int = 0 + self.e2e_done: set[str] = set() + self.per_request: dict[str, dict[str, Any]] = {} + self.sum_per_request_transfer_ms: float = 0.0 + self.wall_start_ts: float = float(wall_start_ts) + self.last_finish_ts: float = float(wall_start_ts) + self.stage_seen_batches: dict[int, set] = {sid: set() for sid in range(self.num_stages)} + self.stage_first_ts: list[float | None] = [None for _ in range(self.num_stages)] + self.stage_last_ts: list[float | None] = [None for _ in range(self.num_stages)] + + def on_stage_metrics(self, stage_id: int, req_id: Any, metrics: dict[str, Any]) -> None: + record_stage_metrics( + self.per_request, + self.stage_req_counts, + self.stage_total_time_ms, + self.stage_total_tokens, + stage_id, + req_id, + metrics, + ) + if self.enable_stats: + compute_and_log_stage_request_stats( + stage_id=stage_id, + request_id=req_id, + batch_size=metrics.get("batch_size"), + num_engine_outputs=metrics.get("num_tokens_out"), + stage_gen_time_ms=metrics.get("stage_gen_time_ms"), + rx_decode_time_ms=metrics.get("rx_decode_time_ms"), + rx_transfer_bytes=metrics.get("rx_transfer_bytes"), + ) + if stage_stats := metrics.get("stage_stats", None): + total_token = int(stage_stats.get("total_token")) + total_gen_time = float(stage_stats.get("total_gen_time")) + _avg_tokens_per_s = (total_token * 1000.0 / total_gen_time) if total_gen_time > 0 else 0.0 + logger.info( + pformat( + { + "type": "Stage_running_avg", + "stage_id": stage_id, + "total_tokens": total_token, + "total_gen_time_ms": total_gen_time, + "avg_tokens_per_s": _avg_tokens_per_s, + }, + sort_dicts=False, + ) + ) + try: + batch_id_raw = metrics.get("batch_id", None) + if batch_id_raw is not None: + batch_id = int(batch_id_raw) + if batch_id not in self.stage_seen_batches[stage_id]: + self.stage_total_time_ms[stage_id] += float(metrics.get("stage_gen_time_ms", 0.0)) + self.stage_seen_batches[stage_id].add(batch_id) + except Exception: + pass + rx_b = float(metrics.get("rx_transfer_bytes", 0.0)) + rx_ms = float(metrics.get("rx_decode_time_ms", 0.0)) + in_flight_ms = float(metrics.get("rx_in_flight_time_ms", 0.0)) + combined = aggregate_rx_and_maybe_total( + self.transfer_edge_req, + self.transfer_agg, + self.per_request, + stage_id, + req_id, + rx_b, + rx_ms, + in_flight_ms, + ) + if self.enable_stats and stage_id > 0: + log_transfer_rx( + stage_id - 1, + stage_id, + req_id, + int(rx_b), + float(rx_ms), + float(in_flight_ms), + ) + if combined is not None: + size_b_c, tx_ms_c, total_ms_c = combined + log_transfer_total( + stage_id - 1, + stage_id, + req_id, + int(size_b_c), + 
float(tx_ms_c), + float(in_flight_ms), + float(rx_ms), + float(total_ms_c), + ) + + def on_forward( + self, + from_stage: int, + to_stage: int, + req_id: Any, + size_bytes: int, + tx_ms: float, + used_shm: bool, + ) -> None: + # Mark first input time for the destination stage if not set + if self.stage_first_ts[to_stage] is None: + self.stage_first_ts[to_stage] = time.time() + if self.enable_stats: + log_transfer_tx( + from_stage, + to_stage, + req_id, + int(size_bytes), + float(tx_ms), + bool(used_shm), + ) + record_sender_transfer_agg( + self.transfer_agg, + self.transfer_edge_req, + from_stage, + to_stage, + req_id, + int(size_bytes), + float(tx_ms), + ) + + def on_finalize_request( + self, + stage_id: int, + req_id: Any, + req_start_ts: float, + ) -> None: + rid_key = str(req_id) + _t0 = float(req_start_ts) + _t1 = time.time() + # Update last output time for this stage + prev_last = self.stage_last_ts[stage_id] + self.stage_last_ts[stage_id] = _t1 if prev_last is None else max(prev_last, _t1) + self.last_finish_ts = max(self.last_finish_ts, _t1) + e2e_ms = (_t1 - _t0) * 1000.0 + + # Sum tokens from all stages for this request + # Include input tokens from stage 0 + output tokens from all stages + pr = self.per_request.setdefault(rid_key, {"stages": {}, "transfers_ms": 0.0, "transfers_bytes": 0}) + total_tokens = 0 + stages_info = pr.get("stages", {}) + for sid, stage_data in stages_info.items(): + # Add input tokens only from stage 0 (initial prompt) + if sid == 0: + total_tokens += int(stage_data.get("num_tokens_in", 0)) + total_tokens += int(stage_data.get("num_tokens_out", 0)) + + self.e2e_total_ms += e2e_ms + self.e2e_total_tokens += total_tokens + self.e2e_count += 1 + self.e2e_done.add(rid_key) + per_req_record = { + "type": "request_level_metrics", + "request_id": rid_key, + "e2e_time_ms": e2e_ms, + "e2e_tpt": (e2e_ms / total_tokens) if total_tokens > 0 else 0.0, + "e2e_total_tokens": total_tokens, + "transfers_total_time_ms": float(pr.get("transfers_ms", 0.0)), + "transfers_total_bytes": int(pr.get("transfers_bytes", 0)), + "stages": stages_info, + } + self.sum_per_request_transfer_ms += float(pr.get("transfers_ms", 0.0)) + logger.info(pformat(per_req_record, sort_dicts=False)) + + def build_and_log_summary(self, final_stage_id_to_prompt: dict[str, int]) -> dict[str, Any]: + # Compute stage summary using wall time between first input and last output per stage + stage_summary: list[dict[str, Any]] = [] + for sid in range(self.num_stages): + first_ts = self.stage_first_ts[sid] + last_ts = self.stage_last_ts[sid] + total_ms = ( + (max(0.0, (last_ts - first_ts)) * 1000.0) if (first_ts is not None and last_ts is not None) else 0.0 + ) + reqs = self.stage_req_counts[sid] + tokens = self.stage_total_tokens[sid] + avg_req = (total_ms / reqs) if reqs > 0 else 0.0 + avg_tok = (tokens * 1000.0 / total_ms) if total_ms > 0 else 0.0 + stage_summary.append( + { + "stage_id": sid, + "requests": int(reqs), + "tokens": int(tokens), + "total_time_ms": float(total_ms), + "avg_time_per_request_ms": float(avg_req), + "avg_tokens_per_s": float(avg_tok), + } + ) + transfer_summary = build_transfer_summary(self.transfer_agg) + e2e_avg_req = (self.e2e_total_ms / self.e2e_count) if self.e2e_count > 0 else 0.0 + e2e_avg_tok = (self.e2e_total_tokens * 1000.0 / self.e2e_total_ms) if self.e2e_total_ms > 0 else 0.0 + wall_time_ms = max(0.0, (self.last_finish_ts - self.wall_start_ts) * 1000.0) + summary: dict[str, Any] = { + "e2e_requests": int(self.e2e_count), + "e2e_total_time_ms": float(wall_time_ms), + 
"e2e_sum_time_ms": float(self.e2e_total_ms), + "e2e_total_tokens": int(self.e2e_total_tokens), + "e2e_avg_time_per_request_ms": e2e_avg_req, + "e2e_avg_tokens_per_s": e2e_avg_tok, + "wall_time_ms": wall_time_ms, + "final_stage_id": final_stage_id_to_prompt, + "stages": stage_summary, + "transfers": transfer_summary, + } + return summary diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py new file mode 100644 index 0000000000000000000000000000000000000000..97357dc3b33297aa0758c8fbb35e11186d1bae70 --- /dev/null +++ b/vllm_omni/entrypoints/omni.py @@ -0,0 +1,866 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import multiprocessing as mp +import os +import time +import uuid +import weakref +from collections.abc import Callable, Generator, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict +from pprint import pformat +from typing import Any, Literal, overload + +from omegaconf import OmegaConf +from tqdm.auto import tqdm +from vllm import SamplingParams +from vllm.logger import init_logger + +from vllm_omni.distributed.omni_connectors import ( + get_stage_connector_config, + initialize_orchestrator_connectors, +) +from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + resolve_omni_kv_config_for_stage, +) +from vllm_omni.distributed.ray_utils.utils import ( + create_placement_group, + get_ray_queue_class, + try_close_ray, +) +from vllm_omni.entrypoints.log_utils import OrchestratorMetrics +from vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType +from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load +from vllm_omni.entrypoints.utils import ( + get_final_stage_id_for_e2e, + inject_omni_kv_config, + load_stage_configs_from_model, + load_stage_configs_from_yaml, + resolve_model_config_path, +) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + + +def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): + """Weak reference cleanup function for OmniBase instances.""" + if stage_list: + for q in stage_in_queues: + try: + q.put_nowait(SHUTDOWN_TASK) + except Exception as e: + logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + for stage in stage_list: + try: + stage.stop_stage_worker() + except Exception as e: + logger.warning(f"Failed to stop stage worker: {e}") + try_close_ray(ray_pg) + + +def _dummy_snapshot_download(model_id): + return model_id + + +def omni_snapshot_download(model_id) -> str: + # TODO: this is just a workaround for quickly use modelscope, we should support + # modelscope in weight loading feature instead of using `snapshot_download` + if os.environ.get("VLLM_USE_MODELSCOPE", False): + from modelscope.hub.snapshot_download import snapshot_download + + return snapshot_download(model_id) + else: + return _dummy_snapshot_download(model_id) + + +class OmniBase: + """Base class for serving Omni models. + + Args: + model: Model name or path to load. + **kwargs: Arbitrary keyword arguments. + - stage_configs_path: Optional path to YAML file containing stage + configurations. If None, configurations are loaded from the model. 
+            - log_stats: Whether to enable statistics logging; stats will
+              be written to files with stage-specific suffixes.
+            - stage_init_timeout: Per-stage init watchdog (seconds). Measured from
+              when the previous stage finished (possibly a prior Omni run with GPU
+              reuse/overlap) to when the current stage starts to initialize.
+            - shm_threshold_bytes: Threshold in bytes for using shared memory
+              for IPC. Objects larger than this threshold will use shared memory.
+            - worker_backend: Backend for worker processes. Default is "multi_process".
+            - ray_address: Address of Ray cluster for Ray backend, if using Ray backend.
+            - batch_timeout: Timeout in seconds for batching requests within a stage.
+            - init_timeout: Timeout in seconds for waiting for all stages to initialize.
+            - Additional keyword arguments passed to stage engines.
+    """
+
+    def __init__(self, model: str, **kwargs: Any) -> None:
+        model = omni_snapshot_download(model)
+        kwargs["model"] = model
+
+        # Stage management attributes
+        self.stage_list: list[OmniStage] = []
+        self._stage_in_queues: list[mp.Queue] = []
+        self._stage_out_queues: list[mp.Queue] = []
+        self._stages_ready: set[int] = set()
+        self._ray_pg = None
+        self._queue_cls = None
+        self._ctx = None
+
+        # Initialize stages - each stage will create appropriate instance based on stage_type
+        # Stage workers will automatically create OmniLLM or OmniDiffusion instances
+        # based on stage_type in YAML config (handled in omni_stage.py)
+        logger.info(f"Initializing stages for model: {model}")
+        self._initialize_stages(model, kwargs)
+
+    def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] | None:
+        if cache_backend == "cache_dit":
+            return {
+                "Fn_compute_blocks": 1,
+                "Bn_compute_blocks": 0,
+                "max_warmup_steps": 4,
+                "residual_diff_threshold": 0.24,
+                "max_continuous_cached_steps": 3,
+                "enable_taylorseer": False,
+                "taylorseer_order": 1,
+                "scm_steps_mask_policy": None,
+                "scm_steps_policy": "dynamic",
+            }
+        if cache_backend == "tea_cache":
+            return {
+                "rel_l1_thresh": 0.2,
+            }
+        return None
+
+    def _normalize_cache_config(self, cache_backend: str | None, cache_config: Any | None) -> Any | None:
+        if isinstance(cache_config, str):
+            try:
+                cache_config = json.loads(cache_config)
+            except json.JSONDecodeError:
+                logger.warning("Invalid cache_config JSON, using defaults.")
+                cache_config = None
+        if cache_config is None and cache_backend not in (None, "", "none"):
+            cache_config = self._get_default_cache_config(cache_backend)
+        return cache_config
+
+    def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Create default diffusion stage configuration."""
+        # We temporarily create a default config for the diffusion stage.
+        # In the future, we should merge the default config with the user-provided config.
+        # TODO: hack, convert dtype to string to avoid non-primitive omegaconf create error.
+        if "dtype" in kwargs:
+            kwargs["dtype"] = str(kwargs["dtype"])
+        cache_backend = kwargs.get("cache_backend", "none")
+        cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None))
+        # TODO: hack, calculate devices based on parallel config.
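+        # Illustrative result of the loop below (assuming parallel_config.world_size == 4):
+        #     devices == "0,1,2,3"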
+ devices = "0" + if "parallel_config" in kwargs: + num_devices = kwargs["parallel_config"].world_size + for i in range(1, num_devices): + devices += f",{i}" + default_stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": { + "process": True, + "devices": devices, + "max_batch_size": 1, + }, + "engine_args": OmegaConf.create( + { + **kwargs, + "cache_backend": cache_backend, + "cache_config": cache_config, + } + ), + "final_output": True, + "final_output_type": "image", + } + ] + default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" + return default_stage_cfg + + def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: + """Initialize stage list management.""" + stage_init_timeout = kwargs.get("stage_init_timeout", 20) + shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) + init_timeout = kwargs.get("init_timeout", 300) + worker_backend = kwargs.get("worker_backend", "multi_process") + ray_address = kwargs.get("ray_address", None) + batch_timeout = kwargs.get("batch_timeout", 10) + stage_configs_path = kwargs.get("stage_configs_path", None) + log_stats = kwargs.get("log_stats", False) + + ### base engine args + tokenizer = kwargs.get("tokenizer", None) + + base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None + + # Load stage configurations from YAML + if stage_configs_path is None: + self.config_path = resolve_model_config_path(model) + self.stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args) + if not self.stage_configs: + default_stage_cfg = self._create_default_diffusion_stage_cfg(kwargs) + self.stage_configs = OmegaConf.create(default_stage_cfg) + else: + self.config_path = stage_configs_path + self.stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args) + + # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. + for cfg in self.stage_configs: + try: + if getattr(cfg, "stage_type", None) != "diffusion": + continue + if not hasattr(cfg, "engine_args") or cfg.engine_args is None: + cfg.engine_args = OmegaConf.create({}) + if kwargs.get("lora_path") is not None: + if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: + cfg.engine_args.lora_path = kwargs["lora_path"] + lora_scale = kwargs.get("lora_scale") + if lora_scale is None: + # Backwards compatibility for older callers. 
+ lora_scale = kwargs.get("static_lora_scale") + if lora_scale is not None: + if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: + cfg.engine_args.lora_scale = lora_scale + except Exception as e: + logger.warning("Failed to inject LoRA config for stage: %s", e) + + # Initialize connectors + self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( + self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes + ) + + # Initialize stats paths + self._enable_stats: bool = bool(log_stats) + + self.worker_backend = worker_backend + self.ray_address = ray_address + self.batch_timeout = batch_timeout + # async chunk remains the same for each stage + self.async_chunk = self._is_async_chunk_enable(self.stage_configs) + + # Build OmniStage instances in parallel, preserve original order + def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: + idx, cfg = idx_cfg + return idx, OmniStage(cfg, stage_init_timeout=stage_init_timeout) + + with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor: + futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)] + results: list[tuple[int, OmniStage]] = [] + for fut in as_completed(futures): + results.append(fut.result()) + results.sort(key=lambda x: x[0]) + self.stage_list = [st for _, st in results] + self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] + self.output_modalities = [st.final_output_type for st in self.stage_list] + logger.debug(f"[{self._name}] Loaded {len(self.stage_list)} stages") + + if self.worker_backend == "ray": + self._queue_cls = get_ray_queue_class() + else: + self._ctx = mp.get_context("spawn") + self._queue_cls = lambda: self._ctx.Queue(maxsize=0) + + self._stage_init_timeout = max(0, int(stage_init_timeout)) + self._shm_threshold_bytes = max(0, int(shm_threshold_bytes)) + self._start_stages(model) + # Wait for all stages to report readiness before seeding + self._wait_for_stages_ready(timeout=init_timeout) + + def _is_async_chunk_enable(self, stage_args: list) -> bool: + """get async chunk flag""" + engine_args = getattr(stage_args[0], "engine_args", None) + return bool(getattr(engine_args, "async_chunk", False)) + + def _start_stages(self, model: str) -> None: + """Start all stage processes.""" + if self.worker_backend == "ray": + # Initialize Ray Cluster + self._ray_pg = create_placement_group( + number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" + ) + + for stage_id, stage in enumerate[OmniStage](self.stage_list): + in_q = self._queue_cls() + out_q = self._queue_cls() + self._stage_in_queues.append(in_q) + self._stage_out_queues.append(out_q) + stage.attach_queues(in_q, out_q) + + stage_connectors_config = get_stage_connector_config( + self.omni_transfer_config, + stage_id, + ) + + # Inject YAML-resolved connector config into omni_kv_config for + # in-engine usage (GPU model runner reads model_config.omni_kv_config). 
+ try: + omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage( + self.omni_transfer_config, stage_id + ) + if omni_conn_cfg: + inject_omni_kv_config(stage, omni_conn_cfg, omni_from, omni_to) # type: ignore + + except Exception as e: + logger.debug("[Omni] Failed to inject omni connector config into stage-%s: %s", stage_id, e) + + stage.init_stage_worker( + model, + is_async=self.is_async, + shm_threshold_bytes=self._shm_threshold_bytes, + ctx=self._ctx if self.worker_backend != "ray" else None, + batch_timeout=self.batch_timeout, + connectors_config=stage_connectors_config, + worker_backend=self.worker_backend, + ray_placement_group=self._ray_pg, + ) + + logger.debug(f"[{self._name}] Stage-{stage_id} process started") + + def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str, Any]) -> None: + self._stages_ready.add(stage_id) + logger.info(f"[{self._name}] Stage-{stage_id} reported ready") + + def _wait_for_stages_ready(self, timeout: int = 120) -> None: + """Wait for all stages to report readiness with optimized polling.""" + num_stages = len(self.stage_list) + deadline = time.time() + max(0, int(timeout)) + + logger.info(f"[{self._name}] Waiting for {num_stages} stages to initialize (timeout: {timeout}s)") + + while len(self._stages_ready) < num_stages and time.time() < deadline: + progressed = False + for stage_id, stage in enumerate(self.stage_list): + if stage_id in self._stages_ready: + continue + + # Check if the stage has reported status + if result := stage.try_collect(): + progressed = True + if result.get("type") == "stage_ready": + self._process_stage_ready(stage, stage_id, result) + + if not progressed: + time.sleep(0.05) + + # Handle Final State + if len(self._stages_ready) == num_stages: + logger.info(f"[{self._name}] All stages initialized successfully") + return + + # Handle Timeout/Failure + not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) + logger.warning( + f"[{self._name}] Initialization timeout: {len(self._stages_ready)}/{num_stages} " + f"stages ready. Missing stages: {not_ready}" + ) + + suggestions = [ + f"Ignore this warning if the model weight download / load from disk time is longer than {timeout}s.", + "Verify GPU/device assignment in config (runtime.devices) is correct.", + "Check GPU/host memory availability; reduce model or batch size if needed.", + "Check model weights path and network reachability (if loading remotely).", + "Increase initialization wait time (stage_init_timeout or call-site timeout).", + ] + + formatted_suggestions = "\n".join(f" {i + 1}) {msg}" for i, msg in enumerate(suggestions)) + + logger.warning(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + + def start_profile(self, stages: list[int] | None = None) -> None: + """Start profiling for specified stages. + + Sends start_profile command to stage workers. Profiling must be enabled + via VLLM_TORCH_PROFILER_DIR environment variable. + + Args: + stages: List of stage IDs to start profiling. If None, starts + profiling for all stages that have profiling enabled. 
+ + Example: + >>> # Profile all stages + >>> omni.start_profile() + >>> outputs = omni.generate(prompts, sampling_params) + >>> omni.stop_profile() + + >>> # Profile only stage 0 and 2 + >>> omni.start_profile(stages=[0, 2]) + """ + if stages is None: + stages = list(range(len(self.stage_list))) + + for stage_id in stages: + if stage_id < len(self.stage_list): + try: + self.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) + logger.info("[%s] Sent start_profile to stage-%s", self._name, stage_id) + except Exception as e: + logger.warning( + "[%s] Failed to send start_profile to stage-%s: %s", + self._name, + stage_id, + e, + ) + + def stop_profile(self, stages: list[int] | None = None) -> dict: + """ + Synchronously stop profiling for specified stages and collect + the file paths for traces and tables. + """ + if stages is None: + stages = list(range(len(self.stage_list))) + + all_results = {"traces": [], "tables": []} + + for stage_id in stages: + if stage_id < len(self.stage_list): + stage = self.stage_list[stage_id] + + # Check if the stage object has our new bridge method + if hasattr(stage, "stop_profile"): + logger.info("[%s] Requesting profile data collection from stage-%s", self._name, stage_id) + + # This is the blocking call that triggers the RPC chain + stage_data = stage.stop_profile() + + if isinstance(stage_data, dict): + # FIX: Handle both single key and list key formats + traces = stage_data.get("trace") or stage_data.get("traces") + tables = stage_data.get("table") or stage_data.get("tables") + + # Debug logging + logger.debug(f"[{self._name}] Stage-{stage_id} returned: {stage_data.keys()}") + if traces: + logger.debug(f"[{self._name}] Stage-{stage_id} traces type: {type(traces)}") + if tables: + logger.debug(f"[{self._name}] Stage-{stage_id} tables type: {type(tables)}") + + # Handle single strings + if traces: + if isinstance(traces, str): + all_results["traces"].append(traces) + elif isinstance(traces, list): + all_results["traces"].extend(traces) + + # Handle single strings + if tables: + if isinstance(tables, str): + all_results["tables"].append(tables) + elif isinstance(tables, list): + all_results["tables"].extend(tables) + else: + logger.warning(f"[{self._name}] Stage-{stage_id} returned no table data") + else: + logger.warning(f"[{self._name}] Stage-{stage_id} returned non-dict data: {type(stage_data)}") + else: + # Fallback for non-diffusion stages + logger.warning( + "[%s] Stage-%s does not support synchronous stop_profile. Falling back to async.", + self._name, + stage_id, + ) + stage.submit({"type": OmniStageTaskType.PROFILER_STOP}) + + # Final debug output + logger.info( + f"[{self._name}] Collected {len(all_results['traces'])} trace(s) and {len(all_results['tables'])} table(s)" + ) + + return all_results + + def close(self) -> None: + """Close all stage processes and clean up resources.""" + if hasattr(self, "_weak_finalizer"): + self._weak_finalizer() + + @property + def _name(self) -> str: + return "OmniBase" + + @property + def is_async(self) -> bool: + return False + + +class Omni(OmniBase): + """Unified entrypoint for both LLM and Diffusion models for better usability. + + Args: + model: Model name or path to load. + **kwargs: Arbitrary keyword arguments. + - stage_configs_path: Optional path to YAML file containing stage + configurations. If None, configurations are loaded from the model. + - log_stats: Whether to enable statistics logging + be written to files with stage-specific suffixes. 
+ - stage_init_timeout: Per-stage init watchdog (seconds). Measured from + when the previous stage finished (possibly a prior Omni run with GPU + reuse/overlap) to when the current stage starts to initialize. + - shm_threshold_bytes: Threshold in bytes for using shared memory + for IPC. Objects larger than this threshold will use shared memory. + - worker_backend: Backend for worker processes. Default is "multi_process". + - ray_address: Address of Ray cluster for Ray backend, if using Ray backend. + - batch_timeout: Timeout in seconds for batching requests within a stage + - init_timeout: Timeout in seconds for waiting for all stages to initialize + - Additional keyword arguments passed to stage engines. + + Example: + >>> omni = Omni(model="Qwen/Qwen2.5-Omni-7B") + >>> outputs = omni.generate(prompts="Hello, world!", sampling_params_list=[SamplingParams()]) + >>> print(outputs) + """ + + def __init__(self, model: str, **kwargs: Any) -> None: + super().__init__(model, **kwargs) + + # Register weak reference cleanup (called on garbage collection) + self._weak_finalizer = weakref.finalize( + self, + _weak_close_cleanup, + self.stage_list, + self._stage_in_queues, + self._ray_pg, + ) + + @overload + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[True], + ) -> Generator[OmniRequestOutput, None, None]: ... + + @overload + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[False] = False, + ) -> list[OmniRequestOutput]: ... + + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: bool = False, + use_tqdm: bool | Callable[..., tqdm] = True, + ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: + """Generate outputs for the given prompts. + + Orchestrates the multi-stage pipeline based on YAML configuration. + Each stage will use OmniLLM or OmniDiffusion based on stage_type. + + Args: + prompts: Input prompt(s) for generation. + sampling_params_list: Optional list of per-stage parameters. + py_generator: Whether the returned result(s) are wrapped in a generator instead of a list. + use_tqdm: Whether to use tqdm progress bar + + Returns: + List of OmniRequestOutput objects, one for each input prompt. + Each output contains the stage_id, final_output_type, and + the request_output from the final stage. + + Raises: + ValueError: If sampling_params_list is None or has incorrect length. + """ + if sampling_params_list is None: + sampling_params_list = self.default_sampling_params_list + elif not isinstance(sampling_params_list, Sequence): + # TODO: After the recent introduction of BAGEL model (one LLM and one Diffusion), + # expect the text_to_image example code to run when only passing one OmniDiffusionSamplingParams + # This behavior may be confusing, and future PR can improve it. 
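+ # Illustrative (hypothetical) case: with a two-stage [LLM, diffusion] pipeline,
+ # passing a single OmniDiffusionSamplingParams applies it to the diffusion stage,
+ # while the LLM stage falls back to its default sampling params from the stage config.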
+ per_stage_params: list[OmniSamplingParams] = [] + for default_stage_sp in self.default_sampling_params_list: + default_sp_type = default_stage_sp.__class__ + if default_sp_type == sampling_params_list.__class__: + per_stage_params.append(sampling_params_list) + else: + per_stage_params.append(default_stage_sp) + sampling_params_list = per_stage_params + + try: + if py_generator: + return self._run_generation_with_generator(prompts, sampling_params_list) + else: + outputs = list(self._run_generation(prompts, sampling_params_list, use_tqdm)) + return outputs + except Exception as e: + logger.exception("[Orchestrator] Failed to run generation: %s", e) + # Always close on exception to ensure cleanup + self.close() + raise e + + def _run_generation_with_generator( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], + ) -> Generator[OmniRequestOutput, None, None]: + """Run generation through all stages in the pipeline and return a generator.""" + gen = self._run_generation(prompts, sampling_params_list) + try: + yield from gen + except Exception as e: + logger.exception("[Orchestrator] Failed to run generation: %s", e) + raise e + finally: + # Cleanup when generator is exhausted or closed + self.close() + + def _run_generation( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], + use_tqdm: bool | Callable[..., tqdm] = True, + ) -> Generator[OmniRequestOutput, None, None]: + """Run generation through all stages in the pipeline.""" + logger.debug(f"[{self._name}] generate() called") + if sampling_params_list is None: + raise ValueError("sampling_params_list is required for pipelined generation") + + if len(sampling_params_list) != len(self.stage_list): + raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + + for i, (stage, sp) in enumerate(zip(self.stage_list, sampling_params_list)): + ExpectedSPType = OmniDiffusionSamplingParams if stage.stage_type == "diffusion" else SamplingParams + if not isinstance(sp, ExpectedSPType): + raise ValueError( + f"Expected sampling parameters with type {ExpectedSPType} in stage {i}, got {sp.__class__}" + ) + + # Normalize prompts to a list for per-request iteration + # str is also Sequence but only test list-like containers here + if isinstance(prompts, str) or not isinstance(prompts, Sequence): + request_prompts: list[OmniPromptType] = [prompts] + else: + request_prompts = list(prompts) + + # Orchestrator keeps stage objects for input derivation + num_stages = len(self.stage_list) + + # Generate globally unique request IDs and map them to original prompts + request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + request_id_to_prompt = {rid: p for rid, p in zip(request_ids, request_prompts)} + + # Track per-request start time for end-to-end timing + _req_start_ts: dict[str, float] = {} + _wall_start_ts: float = time.time() + + # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) + final_stage_id_to_prompt: dict[str, int] = {} + for rid, prompt in request_id_to_prompt.items(): + if isinstance(prompt, dict): + prompt_modalities = prompt.get("modalities", None) + else: + prompt_modalities = None + final_stage_id_for_e2e = get_final_stage_id_for_e2e( + prompt_modalities, self.output_modalities, self.stage_list + ) + final_stage_id_to_prompt[rid] = final_stage_id_for_e2e + + # Metrics/aggregation helper + metrics = 
OrchestratorMetrics( + num_stages, + self._enable_stats, + _wall_start_ts, + ) + + it = request_id_to_prompt.items() + if use_tqdm: + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + it = tqdm_func(it, desc="Adding requests") + + # Seed stage-0 queue with all requests + logger.debug(f"[{self._name}] Seeding {len(request_prompts)} requests into stage-0") + # Mark first input time for stage-0 + metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() + + for req_id, prompt in request_id_to_prompt.items(): + sp0 = sampling_params_list[0] # type: ignore[index] + task = { + "request_id": req_id, + "engine_inputs": prompt, + "sampling_params": sp0, + } + self.stage_list[0].submit(task) + _req_start_ts[req_id] = time.time() + logger.debug(f"[{self._name}] Enqueued request {req_id} to stage-0") + + pbar = None + if use_tqdm: + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=len(request_prompts), + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} unit/s, output: {0:.2f} unit/s"), + ) + # For each stage, forward results to next stage; collect finals at the end + # We pipeline by continually polling output queues in stage order + remaining_by_stage: list[int] = [len(request_prompts)] + [0] * (num_stages - 1) + completed_requests = 0 + total_requests = len(request_prompts) + + logger.debug( + f"[{self._name}] Entering scheduling loop: total_requests={total_requests}, stages={num_stages}", + ) + while completed_requests < total_requests: + made_progress = False + for stage_id, stage in enumerate(self.stage_list): + result = stage.try_collect() + if result is None: + continue + + made_progress = True + req_id = result.get("request_id") + if "error" in result: + logger.error( + f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", + ) + continue + + if result.get("type") == "stage_ready": + # Only happens when stage is initialized slower than expected, + # so we wait for a short time and try again + time.sleep(0.05) + continue + + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + # Mark last output time for this stage whenever we receive outputs + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + try: + _m = result.get("metrics") + if _m is not None: + if not isinstance(_m, dict): + _m = asdict(_m) + metrics.on_stage_metrics(stage_id, req_id, _m) + if pbar: + elapsed = pbar.format_dict["elapsed"] or 1e-6 + # Aggregate total tokens/images across all stages + total_out = sum(metrics.stage_total_tokens) + out_spd = total_out / elapsed + + modality = self.output_modalities[stage_id] + unit = "img" if modality == "image" else "tok" + + # Pre-calculate for cleaner string formatting + if metrics.e2e_count > 0: + avg_lat = metrics.e2e_total_ms / metrics.e2e_count + else: + avg_lat = 0 + + # Align with vLLM's wording "est. speed" using multi-line parentheses + pbar.postfix = ( + f"est. 
speed stage-{stage_id} {unit}/s: {out_spd:.2f}, avg e2e_lat: {avg_lat:.1f}ms" + ) + except Exception as e: + logger.exception( + f"[{self._name}] Failed to process metrics for stage {stage_id}, req {req_id}: {e}", + ) + logger.debug( + f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", + ) + stage.set_engine_outputs(engine_outputs) + + if getattr(stage, "final_output", False): + logger.debug( + f"[{self._name}] Request {req_id} finalized at stage-{stage_id}", + ) + + # End-to-end timing and time-per-token for final output + # (only once per request at the designated final stage) + try: + rid_key = str(req_id) + if stage_id == final_stage_id_to_prompt[req_id] and rid_key not in metrics.e2e_done: + metrics.on_finalize_request( + stage_id, + req_id, + _req_start_ts.get(req_id, _wall_start_ts), + ) + except Exception as e: + logger.exception( + f"[{self._name}] Finalize request handling error for req {req_id} at stage {stage_id}: {e}", + ) + yield OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, # type: ignore[attr-defined] + request_output=engine_outputs, + ) + + next_stage_id = stage_id + 1 + if next_stage_id <= final_stage_id_to_prompt[req_id]: + next_stage: OmniStage = self.stage_list[next_stage_id] + try: + next_inputs = next_stage.process_engine_inputs(self.stage_list, [request_id_to_prompt[req_id]]) + except Exception as e: + logger.exception( + f"[{self._name}] Process engine inputs error for req {req_id}" + f" at stage {next_stage_id}: {e}", + ) + continue + sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + + # Check if we have a connector for this edge + connector_key = (str(stage_id), str(next_stage_id)) + connector = self.connectors.get(connector_key) + sent_via_connector = False + if connector: + sent_via_connector = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=next_inputs, + sampling_params=sp_next, + original_prompt=request_id_to_prompt[req_id], + next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + metrics=metrics, + ) + + if not sent_via_connector: + raise RuntimeError( + f"[{self._name}] Failed to send request {req_id} to stage-{next_stage_id} via connector. " + "Configure a connector for this edge or inspect connector logs for details." 
+ ) + logger.debug( + f"[{self._name}] Forwarded request {req_id} to stage-{next_stage_id}", + ) + remaining_by_stage[next_stage_id] += 1 + else: + completed_requests += 1 + if pbar: + final_mod = self.output_modalities[final_stage_id_to_prompt[req_id]] + pbar.unit = "img" if final_mod == "image" else "req" + pbar.update(1) + logger.debug( + f"[{self._name}] Request {req_id} fully completed ({completed_requests}/{total_requests})", + ) + + if not made_progress: + time.sleep(0.005) + logger.debug(f"[{self._name}] All requests completed") + + if pbar: + pbar.close() + + # Summarize and print stats + try: + summary = metrics.build_and_log_summary(final_stage_id_to_prompt) + logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + except Exception as e: + logger.exception(f"[{self._name}] Failed to build/log summary: {e}") + + @property + def _name(self) -> str: + return "Orchestrator" diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad9a91c80deaadc6e7c574dfdda844241394052 --- /dev/null +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import uuid +from collections.abc import Sequence + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_hf_file_to_dict + +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.outputs import OmniRequestOutput + +# TODO configure logging properly +logging.basicConfig(level=logging.INFO) + +logger = init_logger(__name__) + + +class OmniDiffusion: + """ + It is the main class to interact with vLLM-Omni diffusion models. + It acts as a high-level interface that prepares requests and + delegates the actual diffusion process to the DiffusionEngine. + + You can pass either an `OmniDiffusionConfig` via `od_config`, or + pass kwargs such as `model="Qwen/Qwen-Image"`, + which will be forwarded to `OmniDiffusionConfig.from_kwargs`. + """ + + def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): + # Capture stage info from kwargs before they might be filtered out + stage_id = kwargs.get("stage_id") + engine_input_source = kwargs.get("engine_input_source") + + if od_config is None: + od_config = OmniDiffusionConfig.from_kwargs(**kwargs) + elif isinstance(od_config, dict): + # If config is dict, check it too (priority to kwargs if both exist) + if stage_id is None: + stage_id = od_config.get("stage_id") + if engine_input_source is None: + engine_input_source = od_config.get("engine_input_source") + od_config = OmniDiffusionConfig.from_kwargs(**od_config) + + self.od_config = od_config + + # Inject stage info into omni_kv_config if present + if stage_id is not None: + self.od_config.omni_kv_config.setdefault("stage_id", stage_id) + if engine_input_source is not None: + self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) + + # Diffusers-style models expose `model_index.json` with `_class_name`. + # Bagel models (and other non-diffusers) typically expose `config.json`. 
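+ # Resolution order used below: try the diffusers layout first (model_index.json
+ # plus transformer/config.json); if that fails, fall back to config.json and
+ # detect Bagel checkpoints via model_type/architectures.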
+ try:
+ config_dict = get_hf_file_to_dict(
+ "model_index.json",
+ od_config.model,
+ )
+ od_config.model_class_name = config_dict.get("_class_name", None)
+ od_config.update_multimodal_support()
+
+ tf_config_dict = get_hf_file_to_dict(
+ "transformer/config.json",
+ od_config.model,
+ )
+ od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
+ except (AttributeError, OSError, ValueError):
+ cfg = get_hf_file_to_dict("config.json", od_config.model)
+ if cfg is None:
+ raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}")
+
+ model_type = cfg.get("model_type")
+ architectures = cfg.get("architectures") or []
+ if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
+ od_config.model_class_name = "BagelPipeline"
+ od_config.tf_model_config = TransformerConfig()
+ od_config.update_multimodal_support()
+ else:
+ raise
+
+ self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config)
+
+ def generate(
+ self,
+ prompts: OmniPromptType | Sequence[OmniPromptType],
+ sampling_params: OmniDiffusionSamplingParams,
+ request_ids: list[str] | None = None,
+ ) -> list[OmniRequestOutput]:
+ if isinstance(prompts, (str, dict)):
+ prompts = [prompts]
+ else:
+ prompts = list(prompts)
+
+ # Fill in missing request IDs so every prompt has a unique one
+ # (avoid a shared mutable default argument)
+ if request_ids is None:
+ request_ids = []
+ if len(request_ids) < len(prompts):
+ request_ids.extend(f"{i + len(request_ids)}_{uuid.uuid4()}" for i in range(len(prompts) - len(request_ids)))
+
+ request = OmniDiffusionRequest(prompts, sampling_params, request_ids)
+ return self._run_engine(request)
+
+ def _run_engine(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
+ return self.engine.step(request)
+
+ def close(self) -> None:
+ self.engine.close()
+
+ def __del__(self): # pragma: no cover - best effort cleanup
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def start_profile(self, trace_filename: str | None = None) -> None:
+ """Start profiling for the diffusion model.
+
+ Args:
+ trace_filename: Optional base filename for trace files.
+ If None, a timestamp-based name will be generated.
+ """
+ if hasattr(self, "engine") and self.engine:
+ self.engine.start_profile(trace_filename)
+ else:
+ raise RuntimeError("Diffusion engine not initialized")
+
+ def stop_profile(self) -> dict:
+ """Stop profiling and return profiling results.
+
+ Returns:
+ Dictionary containing paths to trace and table files.
+ """ + if hasattr(self, "engine") and self.engine: + return self.engine.stop_profile() + else: + raise RuntimeError("Diffusion engine not initialized") diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..0356ea231020c6158e13a287c30904e2b2d58486 --- /dev/null +++ b/vllm_omni/entrypoints/omni_llm.py @@ -0,0 +1,241 @@ +from collections.abc import Callable +from typing import Any + +import cloudpickle +from pydantic import ValidationError +from tqdm import tqdm + +# External library imports (vLLM) +from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field +from vllm.entrypoints.llm import LLM +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import get_io_processor +from vllm.usage.usage_lib import UsageContext +from vllm.utils.counter import Counter +from vllm.v1.engine.llm_engine import LLMEngine + +from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors + +# Internal imports (our code) +from vllm_omni.engine.arg_utils import OmniEngineArgs +from vllm_omni.engine.input_processor import OmniInputProcessor +from vllm_omni.engine.output_processor import MultimodalOutputProcessor +from vllm_omni.entrypoints.utils import ( + load_stage_configs_from_model, + load_stage_configs_from_yaml, + resolve_model_config_path, +) + +logger = init_logger(__name__) + + +class OmniLLM(LLM): + """Main entry point for vLLM-Omni inference. + + This class extends the base vLLM LLM class with omni-specific + processors for handling multimodal inputs and outputs. It provides + configuration loading for multi-stage pipelines, while stage management + is handled by the Omni class. + + Args: + model: Model name or path to load + stage_configs_path: Optional path to YAML file containing stage + configurations. If None, configurations are loaded from the model. + log_stats: Whether to enable statistics logging + compilation_config: Optional compilation configuration. Can be an + integer (compilation level), dict, or CompilationConfig instance. + hf_overrides: Optional HuggingFace model configuration overrides + structured_outputs_config: Optional structured outputs configuration. + Can be a dict or StructuredOutputsConfig instance. + init_sleep_seconds: Number of seconds to sleep between starting + each stage process during initialization (used by Omni class) + shm_threshold_bytes: Threshold in bytes for using shared memory + for IPC. Objects larger than this threshold will use shared memory. 
+ batch_timeout: Timeout in seconds for batching requests within a stage + init_timeout: Timeout in seconds for waiting for all stages to initialize + **kwargs: Additional keyword arguments passed to the base LLM class + and engine + + Example: + >>> llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B") + >>> # Stage management is handled by Omni class + """ + + def __init__( + self, + model: str, + stage_configs_path: str | None = None, + log_stats: bool = False, + compilation_config: int | dict[str, Any] | CompilationConfig | None = None, + hf_overrides: dict[str, Any] | None = None, + structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, + init_sleep_seconds: int = 20, + shm_threshold_bytes: int = 65536, + batch_timeout: int = 10, + init_timeout: int = 300, + **kwargs: Any, + ): + """LLM constructor with omni-specific configuration loading.""" + # Store stage management parameters (used by Omni class) + self.worker_backend = kwargs.get("worker_backend", "multi_process") + self.ray_address = kwargs.get("ray_address", None) + self.batch_timeout = batch_timeout + self._enable_stats: bool = bool(log_stats) + + # Load stage configurations + if stage_configs_path is None: + self.config_path = resolve_model_config_path(model) + self.stage_configs = load_stage_configs_from_model(model) + else: + self.config_path = stage_configs_path + self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) + + # Initialize connectors + self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( + self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes + ) + + # Initialize LLM engine + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict): + from vllm.config.kv_transfer import KVTransferConfig + + raw_config_dict = kwargs["kv_transfer_config"] + try: + kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict) + except ValidationError as e: + logger.error( + "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. 
Error: %s", + raw_config_dict, + e, + ) + raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e + + # Extract omni_kv_config from kwargs if present (injected by Omni) + omni_kv_config = kwargs.pop("omni_kv_config", None) + + if compilation_config is not None: + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig(level=compilation_config) + elif isinstance(compilation_config, dict): + compilation_config_instance = CompilationConfig( + **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)} + ) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = CompilationConfig() + + if structured_outputs_config is not None: + if isinstance(structured_outputs_config, dict): + structured_outputs_instance = StructuredOutputsConfig( + **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)} + ) + else: + structured_outputs_instance = structured_outputs_config + else: + structured_outputs_instance = StructuredOutputsConfig() + + engine_args = OmniEngineArgs( + model=model, + compilation_config=compilation_config_instance, + structured_outputs_config=structured_outputs_instance, + omni_kv_config=omni_kv_config, + **kwargs, + ) + + # Create the Engine (autoselects V0 vs V1) + self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.llm_engine.output_processor = MultimodalOutputProcessor( + tokenizer=self.llm_engine.tokenizer, + log_stats=self.llm_engine.log_stats, + engine_core_output_type=engine_args.engine_output_type, + ) + self.llm_engine.input_processor = OmniInputProcessor(vllm_config=self.llm_engine.vllm_config) + self.engine_class = type(self.llm_engine) + + self.request_counter = Counter() + self.default_sampling_params: dict[str, Any] | None = None + + supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore + + logger.info("Supported_tasks: %s", supported_tasks) + + self.supported_tasks = supported_tasks + + # Load the Input/Output processor plugin if any + io_processor_plugin = self.llm_engine.model_config.io_processor_plugin + self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) + self.model_config = self.llm_engine.model_config + self.input_processor = self.llm_engine.input_processor + + def close(self) -> None: + """Close resources. + + Note: Stage management is now handled by Omni class. + This method closes the LLM engine but not stages. + """ + # Close the LLM engine if it exists + if hasattr(self, "llm_engine") and self.llm_engine is not None: + if hasattr(self.llm_engine, "shutdown"): + self.llm_engine.shutdown() + + def __del__(self) -> None: # best-effort + try: + self.close() + except Exception as e: + logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) + + def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), + ) + + # Run the engine. 
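+ # Step the engine until all requests finish, collecting finished outputs and
+ # updating the progress bar with estimated input/output token throughput.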
+ outputs: list[RequestOutput | PoolingRequestOutput] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) + out_spd = total_out_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" + pbar.update(n) + else: + pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() + + if use_tqdm: + pbar.close() + # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. + # This is necessary because some requests may be finished earlier than + # its previous requests. + return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..a2070f89dcfc9a1fb8b079680b7d6f2be737e88b --- /dev/null +++ b/vllm_omni/entrypoints/omni_stage.py @@ -0,0 +1,1593 @@ +""" +Stage manager for orchestrating multiple engines in vLLM-Omni. + +Enhanced to encapsulate per-stage process lifecycle and worker logic +(device setup, LLM init, batching, shared-memory IPC), while preserving +the original input processing utilities for cross-stage data wiring. +""" + +import asyncio +import fcntl +import importlib +import multiprocessing as mp +import os +import queue +import sys +import time +import traceback +from collections.abc import Sequence +from dataclasses import fields +from typing import Any, Literal, cast + +from vllm import PromptType, RequestOutput +from vllm.inputs import TextPrompt +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.tokenizers import TokenizerLike +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine import EngineCoreOutput +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.distributed.omni_connectors import build_stage_connectors +from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector +from vllm_omni.distributed.omni_connectors.connectors.base import OmniConnectorBase +from vllm_omni.distributed.ray_utils.utils import kill_ray_actor, start_ray_actor +from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs +from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion +from vllm_omni.entrypoints.async_omni_llm import AsyncOmniLLM +from vllm_omni.entrypoints.log_utils import count_tokens_from_outputs +from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion +from vllm_omni.entrypoints.omni_llm import OmniLLM +from vllm_omni.entrypoints.stage_utils import ( + SHUTDOWN_TASK, + OmniStageTaskType, + _to_dict, + is_profiler_task, + maybe_dump_to_shm, + set_stage_devices, +) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + + +def 
_resolve_worker_cls(engine_args: dict[str, Any]) -> None: + worker_type = engine_args.pop("worker_type", None) + if not worker_type: + return + if engine_args.get("worker_cls"): + return + from vllm_omni.platforms import current_omni_platform + + worker_type = str(worker_type).lower() + if worker_type == "ar": + engine_args["worker_cls"] = current_omni_platform.get_omni_ar_worker_cls() + elif worker_type == "generation": + engine_args["worker_cls"] = current_omni_platform.get_omni_generation_worker_cls() + else: + raise ValueError(f"Unknown worker_type: {worker_type}") + + +def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: + """Build OmniDiffusionConfig kwargs from engine args.""" + od_config = engine_args.get("od_config", {}) + if not od_config: + od_config = {"model": model} + od_field_names = {f.name for f in fields(OmniDiffusionConfig)} + for key, value in engine_args.items(): + if key in od_field_names: + od_config[key] = value + return od_config + + +class OmniStage: + """Stage manager for orchestrating a single stage in the omni pipeline. + + Encapsulates per-stage process lifecycle and worker logic, including + device setup, LLM initialization, batching, and shared-memory IPC. + Preserves input processing utilities for cross-stage data wiring. + + Args: + stage_config: Stage configuration object containing engine arguments, + runtime settings, and stage-specific parameters + """ + + def __init__(self, stage_config: Any, stage_init_timeout: int = 300): + logger.info(f"[OmniStage] stage_config: {stage_config}") + self.stage_config = stage_config + self.engine = None + self.async_engine = None + self.vllm_config = None + self.tokenizer = None + self.input_preprocessor = None + self.is_tracing_enabled = False + self.stage_id = stage_config.stage_id + self.engine_args = stage_config.engine_args + self.model_stage = stage_config.engine_args.model_stage + self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) + self.engine_input_source = getattr(stage_config, "engine_input_source", []) + self.engine_output_type = getattr(stage_config.engine_args, "engine_output_type", None) + self.engine_outputs = None + self.is_comprehension = getattr(stage_config, "is_comprehension", False) + # Support for different stage types: "llm" (default) or "diffusion" + self.stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") + if hasattr(stage_config, "custom_process_input_func"): + # Import the module specified in the config (already a full module path) + module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) + module = importlib.import_module(module_path) + self.custom_process_input_func = getattr(module, func_name) + else: + self.custom_process_input_func = None + + self.final_output = getattr(stage_config, "final_output", False) + self.final_output_type = getattr(stage_config, "final_output_type", None) + default_sampling_params = getattr(stage_config, "default_sampling_params", {}) + # For LLM stage, this can directly be a SamplingParams-compatible dict; + # For diffusion stage, this only serves as default values for diffusion kwargs. 
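+ # Illustrative (hypothetical) stage-config value for an LLM stage:
+ #   default_sampling_params: {"temperature": 0.7, "max_tokens": 128}
+ # Unknown keys fail fast via the TypeError raised when constructing the params class below.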
+ default_sampling_params = _to_dict(default_sampling_params) + # Further convert it to dataclass to check fields + try: + self.default_sampling_params = ( + SamplingParams if self.stage_type == "llm" else OmniDiffusionSamplingParams + )(**default_sampling_params) + except TypeError as error: + raise TypeError(f"Invalid default_sampling_params for stage {self.stage_id}: {error}") from error + # Runtime orchestration state (added) + self._in_q: mp.Queue | None = None + self._out_q: mp.Queue | None = None + self._proc: mp.Process | None = None + self._shm_threshold_bytes: int = 65536 + self._stage_init_timeout: int = stage_init_timeout + + def set_engine(self, engine: LLMEngine) -> None: + """Set the LLM engine for this stage. + + Args: + engine: LLMEngine instance to use for this stage + """ + self.engine = engine + + def set_async_engine(self, async_engine: AsyncLLM) -> None: + """Set the async LLM engine for this stage. + + Args: + async_engine: AsyncLLM instance to use for this stage + """ + self.async_engine = async_engine + + def set_vllm_config(self, vllm_config: Any) -> None: + """Set the vLLM configuration for this stage. + + Args: + vllm_config: VllmConfig instance received from worker process + """ + self.vllm_config = vllm_config + + def set_tokenizer(self, tokenizer: TokenizerLike) -> None: + """Set the tokenizer for this stage. + + Args: + tokenizer: Tokenizer instance received from worker process + """ + self.tokenizer = tokenizer + + def set_input_preprocessor(self, input_preprocessor: InputPreprocessor) -> None: + """Set the input preprocessor for this stage. + + Args: + input_preprocessor: InputPreprocessor instance received from worker process + """ + self.input_preprocessor = input_preprocessor + + def set_is_tracing_enabled(self, is_tracing_enabled: bool) -> None: + """Set whether tracing is enabled for this stage. + + Args: + is_tracing_enabled: Boolean indicating if tracing is enabled + """ + self.is_tracing_enabled = is_tracing_enabled + + def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: + """Set the engine outputs for this stage. + + Args: + engine_outputs: EngineCoreOutput from this stage's processing + """ + self.engine_outputs = engine_outputs + + # ----------------- New Orchestration APIs ----------------- + def attach_queues(self, in_q: mp.Queue, out_q: mp.Queue) -> None: + """Attach input and output queues for IPC communication. + + Args: + in_q: Input queue for receiving tasks from orchestrator + out_q: Output queue for sending results to orchestrator + """ + self._in_q = in_q + self._out_q = out_q + + def stop_profile(self) -> dict: + """Stop profiling by sending a signal to worker and waiting for response.""" + if self._in_q is None or self._out_q is None: + logger.warning(f"[Stage-{self.stage_id}] Queues not initialized, cannot stop profile.") + return {} + + logger.info(f"[Stage-{self.stage_id}] Sending PROFILER_STOP to worker...") + self.submit({"type": OmniStageTaskType.PROFILER_STOP}) + + # Wait for result from worker + try: + # Profiling stop might take time to flush files, give it 600s + response = self._out_q.get(timeout=600) + + if isinstance(response, dict): + if response.get("type") == "profiler_result": + return response.get("data", {}) + elif "error" in response: + logger.error(f"[Stage-{self.stage_id}] Profiler error: {response['error']}") + return {} + + # If we got something else (e.g. late generation result), we might lose it here, + # but usually profiling stop is called when generation is done. 
+ logger.warning( + f"[Stage-{self.stage_id}] Received unexpected message while waiting for profiler: {response}" + ) + return {} + + except queue.Empty: + logger.error(f"[Stage-{self.stage_id}] Timeout waiting for profiler results.") + return {} + + def init_stage_worker( + self, + model: str, + *, + is_async: bool = False, + shm_threshold_bytes: int = 65536, + ctx: mp.context.BaseContext | None = None, + batch_timeout: int = 10, + connectors_config: dict | None = None, + worker_backend: str = "multi_process", + **kwargs: Any, + ) -> None: + """Initialize and start the stage worker process. + + Creates a worker process that runs the LLM engine for this stage. + The worker handles batching, generation, and IPC communication. + + Args: + model: Model name or path to load + is_async: Whether to use async engine (default: False) + shm_threshold_bytes: Threshold for using shared memory for IPC + ctx: Optional multiprocessing context (default: spawn) + batch_timeout: Timeout in seconds for batching requests + connectors_config: Configuration for stage connectors + worker_backend: Backend type ("multi_process" or "ray") + **kwargs: Additional arguments (e.g. ray_placement_group) + + Raises: + AssertionError: If queues are not attached before calling this method + """ + assert self._in_q is not None and self._out_q is not None, "Queues must be attached before start_process" + + if worker_backend == "ray": + ray_placement_group = kwargs.get("ray_placement_group", None) + assert ray_placement_group is not None, "Ray placement group must be provided" + self._shm_threshold_bytes = sys.maxsize + else: + self._shm_threshold_bytes = shm_threshold_bytes + + ctx = ctx or mp.get_context("spawn") + # Prepare lightweight dict config for worker + engine_args = _to_dict(self.engine_args) + runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) + stage_payload: dict[str, Any] = { + "stage_id": self.stage_id, + "engine_args": engine_args, + "runtime": runtime_cfg, + "shm_threshold_bytes": self._shm_threshold_bytes, + "connectors_config": connectors_config or {}, + "stage_type": self.stage_type, + "engine_input_source": self.engine_input_source, + } + try: + old_env = os.environ.get("VLLM_LOGGING_PREFIX") + new_env = f"[Stage-{self.stage_id}] {'' if old_env is None else old_env}" + os.environ["VLLM_LOGGING_PREFIX"] = new_env + if worker_backend == "ray": + if is_async: + self._ray_actor = start_ray_actor( + _stage_worker_async_entry, + ray_placement_group, + self.stage_id, + self, + model=model, + stage_payload=stage_payload, + batch_timeout=batch_timeout, + stage_init_timeout=self._stage_init_timeout, + ) + else: + self._ray_actor = start_ray_actor( + _stage_worker, + ray_placement_group, + self.stage_id, + model=model, + stage_payload=stage_payload, + in_q=self._in_q, + out_q=self._out_q, + batch_timeout=batch_timeout, + stage_init_timeout=self._stage_init_timeout, + ) + else: + if is_async: + self._proc = ctx.Process( + target=_stage_worker_async_entry, + args=( + self, + model, + stage_payload, + batch_timeout, + self._stage_init_timeout, + ), + ) + else: + self._proc = ctx.Process( + target=_stage_worker, + args=( + model, + stage_payload, + self._in_q, + self._out_q, + batch_timeout, + self._stage_init_timeout, + ), + ) + self._proc.start() + finally: + if old_env is None: + os.environ.pop("VLLM_LOGGING_PREFIX", None) + else: + os.environ["VLLM_LOGGING_PREFIX"] = old_env + + def stop_stage_worker(self) -> None: + """Stop the stage worker process gracefully. 
+ + Sends shutdown signal to the worker and waits for it to terminate. + If graceful shutdown fails, forcefully terminates the process. + Handles both multiprocessing Process and Ray Actor. + """ + if self._in_q is not None: + try: + self._in_q.put_nowait(SHUTDOWN_TASK) + except Exception as e: + logger.warning("Failed to send shutdown to in_q: %s", e) + + if hasattr(self, "_ray_actor") and self._ray_actor: + kill_ray_actor(self._ray_actor) + self._ray_actor = None + elif self._proc is not None: + try: + self._proc.join(timeout=5) + except Exception as e: + logger.debug("join() failed: %s", e) + if self._proc.is_alive(): + try: + self._proc.terminate() + except Exception as e: + logger.warning("terminate() failed: %s", e) + + def submit(self, payload: dict[str, Any]) -> None: + """Submit a task to the stage worker. + + Args: + payload: Dictionary containing task data (request_id, engine_inputs, + sampling_params, etc.) + """ + assert self._in_q is not None + + # [Omni] Inject global request_id into additional_information for cross-stage ID consistency + # This allows workers (like GPUARModelRunner) to use the global ID for side-channel + # operations like KV transfer, even if they use internal IDs for execution. + if "request_id" in payload and "engine_inputs" in payload: + req_id = payload["request_id"] + ein = payload["engine_inputs"] + + # Helper to inject into additional_information + def _inject_global_id(target_ein): + # OmniTokensPrompt is a TypedDict at runtime, so we treat it as a dict + if isinstance(target_ein, dict): + if "additional_information" not in target_ein: + target_ein["additional_information"] = {} + + # Ensure additional_information is a dict before assignment + # (in case it was somehow initialized as None or other type) + if target_ein["additional_information"] is None: + target_ein["additional_information"] = {} + + if isinstance(target_ein["additional_information"], dict): + # Wrap in list because OmniInputProcessor requires Tensor or list values + target_ein["additional_information"]["global_request_id"] = [str(req_id)] + + if isinstance(ein, list): + for item in ein: + _inject_global_id(item) + else: + _inject_global_id(ein) + + self._in_q.put(payload) + + def try_collect(self) -> dict[str, Any] | None: + """Try to collect a result from the stage worker without blocking. + + Returns: + Result dictionary if available, None otherwise. Result contains + request_id, engine_outputs (or engine_outputs_shm), and metrics. + """ + assert self._out_q is not None + try: + return self._out_q.get_nowait() + except Exception: + return None + + def process_engine_inputs( + self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None + ) -> list[OmniTokensPrompt | TextPrompt]: + """Process engine inputs for this stage from upstream stage outputs. + + Derives inputs for this stage from outputs of upstream stages. + Uses engine_input_source configuration to determine which upstream + stage outputs to use. Supports custom processing functions. 
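+
+ By default (no custom_process_input_func configured), the token IDs produced
+ by the first stage listed in engine_input_source become this stage's
+ prompt_token_ids, and multi_modal_data from the original prompt is carried
+ over when requires_multimodal_data is set.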
+ + Args: + stage_list: List of all stages in the pipeline + prompt: Optional original prompt (for multimodal data preservation) + + Returns: + List of processed engine inputs ready for this stage + + Raises: + ValueError: If engine_input_source is empty or invalid + """ + if self.custom_process_input_func is None: + engine_inputs = [] + if len(self.engine_input_source) == 0: + raise ValueError("engine_input_source is empty") + source_stage_id = self.engine_input_source[0] + source_outputs = stage_list[source_stage_id].engine_outputs + if not isinstance(prompt, list): + prompt = [prompt] + multi_modal_data = { + source_output.request_id: p.get("multi_modal_data", None) + for source_output, p in zip(source_outputs, prompt) + } + + for source_output in source_outputs: + engine_input = OmniTokensPrompt( + prompt_token_ids=source_output.outputs[0].token_ids, + multi_modal_data=( + multi_modal_data[source_output.request_id] + if self.requires_multimodal_data and multi_modal_data + else None + ), + ) + engine_inputs.append(engine_input) + return engine_inputs + + else: + engine_input_source = self.engine_input_source + return self.custom_process_input_func( + stage_list, engine_input_source, prompt, self.requires_multimodal_data + ) + + +def _stage_worker( + model: str, + stage_payload: dict[str, Any], + in_q: mp.Queue, + out_q: mp.Queue, + batch_timeout: int = 10, + stage_init_timeout: int = 300, +) -> None: + """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" + # Use local aliases to avoid conflicts with global imports in worker process + logger.info(f"Starting stage worker with model: {model}") + import multiprocessing as _mp + import os as _os + import time as _time + + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / + # GPUARModelRunner) are spawned with a fork-safe method. + # Mooncake / gRPC / RDMA and CUDA/NCCL can deadlock under fork-with-threads. + if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.info("[Stage] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") + # Best-effort: also force python mp start method in this stage process. + # This may raise if already set; that's fine. 
+ try: + _mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + + stage_id = stage_payload["stage_id"] + engine_args = stage_payload.get("engine_args", {}) + runtime_cfg = stage_payload.get("runtime", {}) + shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) + connectors_config = stage_payload.get("connectors_config", {}) + stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") + + if stage_type != "diffusion": + _resolve_worker_cls(engine_args) + + # Aggregates for running average + _agg_total_tokens = 0 + _agg_total_gen_time_ms = 0.0 + # Monotonic batch id per stage process for orchestrator dedup on time aggregation + _batch_seq = 0 + + # Device mapping + device_type = None + try: + from vllm_omni.platforms import current_omni_platform + + device_type = current_omni_platform.device_type + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) + except Exception as e: + logger.warning("Device setup failed: %s", e) + + # Sequential initialization on the same device to avoid memory calculation errors + # when multiple instances start simultaneously + # For TP/PP/DP/SP, we need to lock ALL devices that will be used by this stage + lock_files = [] + try: + # Get all parallel sizes from engine_args or parallel_config (defaults to 1) + if "parallel_config" in engine_args: + parallel_config = engine_args["parallel_config"] + tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) + pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) + data_parallel_size = parallel_config.get("data_parallel_size", 1) + prefill_context_parallel_size = parallel_config.get("prefill_context_parallel_size", 1) + sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) + cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1) + else: + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) + data_parallel_size = engine_args.get("data_parallel_size", 1) + prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) + sequence_parallel_size = 1 # not use in omni model + cfg_parallel_size = 1 # not used in omni model + + # Calculate total number of devices needed for this stage + # For a single stage worker: + # - TP: splits model across devices (always needed) + # - PP: splits layers across pipeline stages, but each stage uses TP devices + # - DP: replicates model, but each replica uses TP devices + # - PCP: context parallelism, typically uses TP devices + # - SP: sequence parallelism, typically uses TP devices + # - CFG: Classifier-Free Guidance parallelism for diffusion models + # The number of devices per stage is determined by TP * PP * DP * PCP * SP * CFG size + # (PP/DP/PCP are higher-level parallelism that don't add devices per stage) + num_devices_per_stage = ( + tensor_parallel_size + * pipeline_parallel_size + * data_parallel_size + * prefill_context_parallel_size + * sequence_parallel_size + * cfg_parallel_size + ) + + # Get physical device IDs from device control env var (e.g., CUDA_VISIBLE_DEVICES) + # After set_stage_devices, this env var is set to physical device(s) + device_control_env = current_omni_platform.device_control_env_var + visible_devices_str = _os.environ.get(device_control_env) + physical_devices = [] + + if visible_devices_str: + try: + physical_devices = [int(x.strip()) for x in visible_devices_str.split(",") if x.strip()] + except 
(ValueError, IndexError): + pass + + if not physical_devices: + # Fallback: use logical device count if device control env var not set + num_devices = current_omni_platform.get_device_count() + physical_devices = list(range(num_devices)) + + # Determine which devices will be used (min of devices per stage and available devices) + num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) + devices_to_lock = physical_devices[:num_devices_to_lock] + + # Sort devices_to_lock to prevent deadlock (all processes acquire locks in same order) + devices_to_lock = sorted(devices_to_lock) + + logger.debug( + "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d, CFG=%d; will lock %d devices: %s", + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + prefill_context_parallel_size, + sequence_parallel_size, + cfg_parallel_size, + num_devices_to_lock, + devices_to_lock, + ) + + # Acquire exclusive locks for all devices using fcntl.flock + # Locks are automatically released when process dies + wait_start = _time.time() + acquired_lock_fds = [] # Store file descriptors to keep locks alive + + for device_id in devices_to_lock: + lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" + lock_acquired = False + + while not lock_acquired: + try: + # Open or create the lock file + lock_fd = _os.open(lock_file, _os.O_CREAT | _os.O_RDWR, 0o644) + + # Try to acquire exclusive lock (non-blocking first) + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + # Successfully acquired lock - write PID + _os.ftruncate(lock_fd, 0) # Clear file + _os.write(lock_fd, f"{_os.getpid()}\n".encode()) + _os.fsync(lock_fd) # Ensure written to disk + lock_acquired = True + acquired_lock_fds.append(lock_fd) + logger.debug("Acquired exclusive lock for device %s", device_id) + except BlockingIOError: + # Lock is held by another process + _os.close(lock_fd) + + # Check if we've been waiting too long + if _time.time() - wait_start > stage_init_timeout: + logger.warning( + "Timeout waiting for device %s initialization lock, proceeding anyway", + device_id, + ) + break + + # Wait a bit before retrying + _time.sleep(0.1) + except OSError as e: + # Other error - log and continue without lock + logger.debug( + "Failed to acquire lock for device %s: %s, continuing anyway", + device_id, + e, + ) + try: + _os.close(lock_fd) + except (OSError, NameError): + pass + break + + lock_files = acquired_lock_fds + + # Set FD_CLOEXEC on all lock file descriptors to prevent child processes + # (e.g., EngineCore) from inheriting them, which would cause deadlock + for lock_fd in acquired_lock_fds: + try: + flags = fcntl.fcntl(lock_fd, fcntl.F_GETFD) + fcntl.fcntl(lock_fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + except (OSError, ValueError): + pass + except Exception as e: + logger.debug( + "[Stage-%s] Failed to set up sequential initialization lock: %s", + stage_id, + e, + ) + + # Init engine based on stage_type + logger.debug("[Stage-%s] Initializing %s engine with args keys=%s", stage_id, stage_type, list(engine_args.keys())) + if engine_args.get("async_chunk", False): + logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) + stage_connector_spec = {} + for v in connectors_config.values(): + stage_connector_spec = dict(v.get("spec", {})) + break + engine_args["stage_connector_spec"] = stage_connector_spec + engine_args["stage_id"] = stage_id + if stage_type == "diffusion": + engine_args.pop("model_stage", None) + engine_args.pop("model", None) + stage_engine = OmniDiffusion( + 
model=model, + stage_id=stage_id, + engine_input_source=stage_payload.get("engine_input_source", []), + **engine_args, + ) + else: + # Default to LLM engine + stage_engine = OmniLLM(model=model, **engine_args) + + # Release all locks AFTER engine initialization completes + for lock_fd in lock_files: + try: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + _os.close(lock_fd) + logger.debug("Released initialization lock (fd=%s)", lock_fd) + except (OSError, ValueError): + pass + lock_files = [] # Clear after release + + logger.debug("Engine initialized") + # Initialize OmniConnectors if configured + connectors: dict[tuple[str, str], OmniConnectorBase] | None = {} + if connectors_config: + connectors = build_stage_connectors( + stage_id=stage_id, + connectors_config=connectors_config, + ) + if connectors is None: + return + + # Signal readiness to orchestrator + try: + out_q.put({"type": "stage_ready", "stage_id": stage_id}) + except Exception: + pass + + max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) + logger.info(f"Max batch size: {max_batch_size}") + + def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: + """Handle profiler task locally in the worker process.""" + if task_type == OmniStageTaskType.PROFILER_START: + if stage_type == "diffusion": + try: + profile_dir = _os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") + _os.makedirs(profile_dir, exist_ok=True) + trace_filename = f"stage_{stage_id}_diffusion_{int(_time.time())}" + stage_engine.start_profile(trace_filename=trace_filename) + logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) + else: + try: + stage_engine.start_profile() + logger.info("[Stage-%s] vLLM profiler started", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) + return {} + + elif task_type == OmniStageTaskType.PROFILER_STOP: + if stage_type == "diffusion": + try: + # CRITICAL: Capture return value + result_data = stage_engine.stop_profile() + logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) + return result_data + except Exception as e: + logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) + return {} + else: + try: + stage_engine.stop_profile() + logger.info("[Stage-%s] vLLM profiler stopped", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) + return {} + return {} + + # Batch processing loop + while True: + task = in_q.get() + + _recv_dequeue_ts = _time.time() + task_type = task.get("type", OmniStageTaskType.GENERATE) + if task_type == OmniStageTaskType.SHUTDOWN: + logger.info("Received shutdown signal") + break + + # Handle profiler control commands + if is_profiler_task(task_type): + profiler_data = handle_profiler_task_local(task_type) + # If it was a STOP command, we must reply to the Orchestrator + if task_type == OmniStageTaskType.PROFILER_STOP: + out_q.put({"type": "profiler_result", "data": profiler_data}) + continue + + batch_tasks: list[dict[str, Any]] = [task] + tasks_failed_to_add_to_batch: list[dict[str, Any]] = [] + start_time = _time.time() + if max_batch_size > 1: + while len(batch_tasks) < max_batch_size: + if not in_q.empty(): + extra = in_q.get_nowait() + if extra == SHUTDOWN_TASK: + in_q.put(SHUTDOWN_TASK) + break + # Handle profiler commands that arrive during batching + extra_type = extra.get("type") if isinstance(extra, 
dict) else None + if is_profiler_task(extra_type): + p_data = handle_profiler_task_local(extra_type) + if extra_type == OmniStageTaskType.PROFILER_STOP: + out_q.put({"type": "profiler_result", "data": p_data}) + continue + # Ensure that all tasks have the same sampling params + # If no, put them in a temporary container and add back to queue + # This should be always true, because user only calls omni.generate() once and it blocks + # User can only pass one sampling param object, but the list of prompts are separated. + if task.get("sampling_params") != extra.get("sampling_params"): + logger.warning( + """In offline mode, expect all prompts in one `omni.generate()` call to share same sampling params""" # noqa: E501 # line too long + f"""However, prompt {task.get("engine_inputs")} has sampling params {task.get("sampling_params")}, """ # noqa: E501 # line too long + f"""whereas the prompt {extra.get("engine_inputs")} has sampling params {extra.get("sampling_params")}.""" # noqa: E501 # line too long + """The two tasks cannot be combined in one batch request.""" + ) + tasks_failed_to_add_to_batch.append(extra) + else: + batch_tasks.append(extra) + end_time = _time.time() + duration = end_time - start_time + if duration > batch_timeout: + break + else: + continue + else: + end_time = _time.time() + duration = end_time - start_time + _time.sleep(0.05) + if duration > batch_timeout: + break + else: + continue + for task_to_readd in tasks_failed_to_add_to_batch: + in_q.put(task_to_readd) + # Ensure that the popped tasks are with identical sampling params. Take one of them. + batch_engine_sampling_params: OmniSamplingParams = batch_tasks[0]["sampling_params"] + + batch_request_ids: list[Any] = [] + batch_engine_inputs: list[OmniPromptType] = [] + _rx_bytes_by_rid: dict[Any, int] = {} + _rx_decode_ms_by_rid: dict[Any, float] = {} + _in_flight_ms_by_rid: dict[Any, float] = {} + for t in batch_tasks: + rid = t["request_id"] + try: + sent_ts = float(t.get("sent_ts", None)) if isinstance(t, dict) else None + if sent_ts is not None: + _in_flight_ms_by_rid[rid] = (_recv_dequeue_ts - sent_ts) * 1000.0 + else: + _in_flight_ms_by_rid[rid] = 0.0 + except Exception: + _in_flight_ms_by_rid[rid] = 0.0 + + # Resolve input data strictly via connectors if payload + # is larger than shm_threshold_bytes or using other connectors + ein, _rx_metrics = try_recv_via_connector( + task=t, + connectors=connectors, + stage_id=stage_id, + ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) + + if ein is None or _rx_metrics is None: + raise RuntimeError( + f"[Stage-{stage_id}] Missing connector payload for request {rid}. " + "Ensure connectors are configured for all incoming edges." 
+ ) + + _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) + _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) + + batch_request_ids.append(rid) + if isinstance(ein, (str, dict)): + # Types like OmniTextPrompt, TextPrompt are TypedDict, essentially dict and enters this branch + batch_engine_inputs.append(ein) + elif isinstance(ein, Sequence): + batch_engine_inputs.extend(ein) + else: + # Other unknown types, append as-is + batch_engine_inputs.append(ein) + logger.debug( + "Received batch size=%d, request_ids=%s", + len(batch_tasks), + batch_request_ids, + ) + try: + _batch_seq += 1 + gen_outputs: list[OmniRequestOutput | RequestOutput] = [] + _gen_t0 = _time.time() + if stage_type == "diffusion": + stage_engine = cast(OmniDiffusion, stage_engine) + batch_engine_sampling_params = cast(OmniDiffusionSamplingParams, batch_engine_sampling_params) + # Diffusion generate returns results directly, not an iterator + diffusion_results = stage_engine.generate( + batch_engine_inputs, batch_engine_sampling_params, batch_request_ids + ) + gen_outputs.extend(diffusion_results) + # Assign request_ids if not present + for idx, result in enumerate(gen_outputs): + if not hasattr(result, "request_id") or result.request_id is None: + if idx < len(batch_request_ids): + result.request_id = batch_request_ids[idx] + else: + stage_engine = cast(OmniLLM, stage_engine) + batch_engine_sampling_params = cast(SamplingParams, batch_engine_sampling_params) + results = stage_engine.generate( + batch_engine_inputs, # type: ignore # silent complaints about list of subclassed TypedDict + batch_engine_sampling_params, + use_tqdm=False, + ) + gen_outputs.extend(results) + _gen_t1 = _time.time() + _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 + logger.debug(f"Generate done: batch={len(batch_tasks)}, req_ids={batch_request_ids}, gen_ms={_gen_ms:.1f}") + + # Group outputs per request id with fallback + req_to_outputs: dict[Any, list[Any]] = {rid: [] for rid in batch_request_ids} + unmapped: list[Any] = [] + for ro in gen_outputs: + rid = ro.request_id + if rid in req_to_outputs: + req_to_outputs[rid].append(ro) + else: + unmapped.append(ro) + if unmapped: + idx = 0 + for ro in unmapped: + target_rid = batch_request_ids[idx % len(batch_request_ids)] + ro.request_id = target_rid + req_to_outputs[target_rid].append(ro) + idx += 1 + + _agg_total_gen_time_ms += _gen_ms + + # Emit per-request results + for i, rid in enumerate(batch_request_ids): + r_outputs = req_to_outputs.get(rid, []) + _metrics = make_request_stats( + r_outputs, + _gen_ms, + int(_batch_seq), + int(len(batch_request_ids)), + float(_rx_decode_ms_by_rid.get(rid, 0.0)), + int(_rx_bytes_by_rid.get(rid, 0)), + float(_in_flight_ms_by_rid.get(rid, 0.0)), + ) + _agg_total_tokens += _metrics.num_tokens_out + if i == len(batch_request_ids) - 1: + _metrics.stage_stats = make_stage_stats(_agg_total_tokens, _agg_total_gen_time_ms) + else: + _metrics.stage_stats = None + try: + use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) + if use_shm: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs_shm": payload, + "metrics": _metrics, + } + ) + else: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs": payload, + "metrics": _metrics, + } + ) + except Exception: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs": r_outputs, + "metrics": _metrics, + } + ) + logger.debug( + "Enqueued result for request %s to downstream", + rid, + ) + except 
Exception as e: + logger.exception("Failed on batch %s: %s", batch_request_ids, e) + _tb = traceback.format_exc() + for rid in batch_request_ids: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "error": str(e), + "error_tb": _tb, + } + ) + + +def _stage_worker_async_entry( + omni_stage: OmniStage, + model: str, + stage_payload: dict[str, Any], + batch_timeout: int = 10, + stage_init_timeout: int = 300, +) -> None: + asyncio.run(_stage_worker_async(omni_stage, model, stage_payload, batch_timeout, stage_init_timeout)) + + +async def _stage_worker_async( + omni_stage: OmniStage, + model: str, + stage_payload: dict[str, Any], + batch_timeout: int = 10, + stage_init_timeout: int = 300, +) -> None: + """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" + # Use local aliases to avoid conflicts with global imports in worker process + import multiprocessing as _mp + import os as _os + import time as _time + + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / + # GPUARModelRunner) are spawned with a fork-safe method. + if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.info("[Stage-async] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") + try: + _mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + + stage_id = stage_payload["stage_id"] + engine_args = stage_payload.get("engine_args", {}) + runtime_cfg = stage_payload.get("runtime", {}) + shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) + connectors_config = stage_payload.get("connectors_config", {}) + stage_type = stage_payload.get("stage_type", "llm") + + if stage_type != "diffusion": + _resolve_worker_cls(engine_args) + + in_q = omni_stage._in_q + out_q = omni_stage._out_q + + # Aggregates for running average + _agg_total_tokens = 0 + _agg_total_gen_time_ms = 0.0 + # Monotonic batch id per stage process for orchestrator dedup on time + # aggregation + _batch_seq = 0 + + # Device mapping + device_type = None + try: + from vllm_omni.platforms import current_omni_platform + + device_type = current_omni_platform.device_type + set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) + except Exception as e: + logger.warning("Device setup failed: %s", e) + + # Initialize OmniConnectors if configured to match sync worker behavior + connectors: dict[Any, Any] = {} + if connectors_config: + built_connectors = build_stage_connectors( + stage_id=stage_id, + connectors_config=connectors_config, + ) + if built_connectors is None: + return + connectors = built_connectors + + # Sequential initialization on the same device to avoid memory calculation errors + # when multiple instances start simultaneously + # For TP/PP/DP/PCP, we need to lock ALL devices that will be used by this stage + lock_files = [] + try: + # Get all parallel sizes from engine_args or parallel_config (defaults to 1) + if "parallel_config" in engine_args: + parallel_config = engine_args["parallel_config"] + tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) + pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) + data_parallel_size = parallel_config.get("data_parallel_size", 1) + prefill_context_parallel_size = parallel_config.get("prefill_context_parallel_size", 1) + sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) + cfg_parallel_size = 
parallel_config.get("cfg_parallel_size", 1) + else: + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) + data_parallel_size = engine_args.get("data_parallel_size", 1) + prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) + sequence_parallel_size = 1 # not use in omni model + cfg_parallel_size = 1 # not used in omni model + + # Calculate total number of devices needed for this stage + # For a single stage worker: + # - TP: splits model across devices (always needed) + # - PP: splits layers across pipeline stages, but each stage uses TP devices + # - DP: replicates model, but each replica uses TP devices + # - PCP: context parallelism, typically uses TP devices + # - SP: sequence parallelism, typically uses TP devices + # - CFG: Classifier-Free Guidance parallelism for diffusion models + # The number of devices per stage is determined by TP * PP * DP * PCP * SP * CFG size + # (PP/DP/PCP are higher-level parallelism that don't add devices per stage) + num_devices_per_stage = ( + tensor_parallel_size + * pipeline_parallel_size + * data_parallel_size + * prefill_context_parallel_size + * sequence_parallel_size + * cfg_parallel_size + ) + + # Get physical device IDs from device control env var (e.g., CUDA_VISIBLE_DEVICES) + # After set_stage_devices, this env var is set to physical device(s) + device_control_env = current_omni_platform.device_control_env_var + visible_devices_str = _os.environ.get(device_control_env) + physical_devices = [] + + if visible_devices_str: + try: + physical_devices = [int(x.strip()) for x in visible_devices_str.split(",") if x.strip()] + except (ValueError, IndexError): + pass + + if not physical_devices: + # Fallback: use logical device count if device control env var not set + num_devices = current_omni_platform.get_device_count() + physical_devices = list(range(num_devices)) + + # Determine which devices will be used (min of devices per stage and available devices) + num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) + devices_to_lock = physical_devices[:num_devices_to_lock] + + # Sort devices_to_lock to prevent deadlock (all processes acquire locks in same order) + devices_to_lock = sorted(devices_to_lock) + + logger.debug( + "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d, CFG=%d; will lock %d devices: %s", + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + prefill_context_parallel_size, + sequence_parallel_size, + cfg_parallel_size, + num_devices_to_lock, + devices_to_lock, + ) + + # Acquire exclusive locks for all devices using fcntl.flock + # Locks are automatically released when process dies + wait_start = _time.time() + acquired_lock_fds = [] # Store file descriptors to keep locks alive + + for device_id in devices_to_lock: + lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" + lock_acquired = False + + while not lock_acquired: + try: + # Open or create the lock file + lock_fd = _os.open(lock_file, _os.O_CREAT | _os.O_RDWR, 0o644) + + # Try to acquire exclusive lock (non-blocking first) + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + # Successfully acquired lock - write PID + _os.ftruncate(lock_fd, 0) # Clear file + _os.write(lock_fd, f"{_os.getpid()}\n".encode()) + _os.fsync(lock_fd) # Ensure written to disk + lock_acquired = True + acquired_lock_fds.append(lock_fd) + logger.debug("Acquired exclusive lock for device %s", device_id) + except BlockingIOError: + # Lock 
is held by another process + _os.close(lock_fd) + + # Check if we've been waiting too long + if _time.time() - wait_start > stage_init_timeout: + logger.warning( + "Timeout waiting for device %s initialization lock, proceeding anyway with timeout %s", + device_id, + stage_init_timeout, + ) + break + + # Wait a bit before retrying + _time.sleep(0.1) + except OSError as e: + # Other error - log and continue without lock + logger.debug( + "Failed to acquire lock for device %s: %s, continuing anyway", + device_id, + e, + ) + try: + _os.close(lock_fd) + except (OSError, NameError): + pass + break + + lock_files = acquired_lock_fds + except Exception as e: + logger.debug("Failed to set up sequential initialization lock: %s", e) + + # Init engine based on stage_type + logger.debug( + "[Stage-%s] Initializing %s engine with args keys=%s", + stage_id, + stage_type, + list(engine_args.keys()), + ) + if engine_args.get("async_chunk", False): + logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) + stage_connector_spec = {} + for v in connectors_config.values(): + stage_connector_spec = dict(v.get("spec", {})) + break + engine_args["stage_connector_spec"] = stage_connector_spec + engine_args["stage_id"] = stage_id + try: + if stage_type == "diffusion": + # For diffusion, we need to extract diffusion-specific config + od_config = _build_od_config(engine_args, model) + + # Inject omni config for worker to access stage info + if "omni_kv_config" not in od_config: + od_config["omni_kv_config"] = {} + od_config["omni_kv_config"]["stage_id"] = stage_id + od_config["omni_kv_config"]["engine_input_source"] = stage_payload.get("engine_input_source", []) + + logger.debug(f"[Stage-%s] Initializing diffusion engine with config: {od_config}", stage_id) + stage_engine = AsyncOmniDiffusion( + model=model, + od_config=od_config, + **{k: v for k, v in engine_args.items() if k not in {"od_config", "model"}}, + ) + vllm_config = None # Diffusion doesn't use vllm_config + else: + omni_engine_args = AsyncOmniEngineArgs(model=model, **engine_args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = omni_engine_args.create_engine_config(usage_context=usage_context) + stage_engine = AsyncOmniLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + engine_args=omni_engine_args, + ) + finally: + # Release all locks by closing file descriptors + # Locks are automatically released when file descriptors are closed + # or when process dies + for lock_fd in lock_files: + try: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + _os.close(lock_fd) + logger.debug("Released initialization lock (fd=%s)", lock_fd) + except (OSError, ValueError): + pass + omni_stage.set_async_engine(stage_engine) + if hasattr(omni_stage.async_engine, "log_stats") and omni_stage.async_engine.log_stats: + + async def _force_log(): + try: + while True: + await asyncio.sleep(10.0) + await omni_stage.async_engine.do_log_stats() + except asyncio.CancelledError: + pass + + log_stats_task = asyncio.create_task(_force_log()) + else: + log_stats_task = None + + # Don't keep the dummy data in memory (only for LLM engines) + if stage_type != "diffusion": + await stage_engine.reset_mm_cache() + logger.debug("[Stage-%s] Engine initialized", stage_id) + + async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: + """Handle profiler task asynchronously for both LLM and diffusion stages.""" + if task_type == OmniStageTaskType.PROFILER_START: + if stage_type == "diffusion": + try: + # Sync call 
is safe here — diffusion profiling is lightweight + profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") + os.makedirs(profile_dir, exist_ok=True) + trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}" + stage_engine.start_profile(trace_filename=trace_filename) + logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) + else: + try: + await stage_engine.start_profile() + logger.info("[Stage-%s] vLLM profiler started", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) + + elif task_type == OmniStageTaskType.PROFILER_STOP: + if stage_type == "diffusion": + try: + trace_files = stage_engine.stop_profile() + logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) + if trace_files: + logger.info("Diffusion trace files: %s", trace_files) + except Exception as e: + logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) + else: + try: + await stage_engine.stop_profile() + logger.info("[Stage-%s] vLLM profiler stopped", stage_id) + except Exception as e: + logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) + + # Signal readiness to orchestrator and send vllm_config back to main process + try: + # Send vllm_config back to main process so it can be accessed via + # get_vllm_config(). This is needed because async_engine is only available + # in the worker process + + # input_preprocessor = await stage_engine.get_input_preprocessor() + stage_ready_payload = { + "type": "stage_ready", + "stage_id": stage_id, + "vllm_config": vllm_config, + "tokenizer": getattr(stage_engine, "tokenizer", None), + } + # Only add is_tracing_enabled for LLM engines + if stage_type != "diffusion": + stage_ready_payload["is_tracing_enabled"] = await stage_engine.is_tracing_enabled() + out_q.put(stage_ready_payload) + except Exception as e: + logger.warning("Failed to send stage ready signal: %s", e) + generation_out_q = asyncio.Queue() + + # Batch processing loop + _rx_bytes_by_rid: dict[Any, int] = {} + _rx_decode_ms_by_rid: dict[Any, float] = {} + _in_flight_ms_by_rid: dict[Any, float] = {} + + async def generation_single_request(task: dict[str, Any]): + _recv_dequeue_ts = _time.time() + rid = task["request_id"] + try: + sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None + if sent_ts is not None: + _in_flight_ms_by_rid[rid] = (_recv_dequeue_ts - sent_ts) * 1000.0 + else: + _in_flight_ms_by_rid[rid] = 0.0 + except Exception: + _in_flight_ms_by_rid[rid] = 0.0 + try: + ein, _rx_metrics = try_recv_via_connector( + task=task, + connectors=connectors, + stage_id=stage_id, + ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) + + if ein is None or _rx_metrics is None: + raise RuntimeError( + f"[Stage-{stage_id}] Missing connector payload for request {rid}. " + "Ensure connectors are configured for all incoming edges." 
+ ) + _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) + _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) + + logger.debug("Received batch size=1, request_ids=%s", rid) + _gen_t0 = _time.time() + if isinstance(ein, Sequence) and not isinstance(ein, str): + ein = ein[0] + + if stage_type == "diffusion": + diffusion_sampling_params = cast(OmniDiffusionSamplingParams, task["sampling_params"]) + # AsyncOmniDiffusion.generate returns a single result, not an async generator + gen_output = await cast(AsyncOmniDiffusion, stage_engine).generate(ein, diffusion_sampling_params, rid) + _gen_t1 = _time.time() + _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 + await generation_out_q.put((rid, gen_output, _gen_ms)) + else: + ein = cast(PromptType, ein) + llm_sampling_params: SamplingParams = task["sampling_params"] + gen_output = None + async for res in cast(AsyncLLM, stage_engine).generate(ein, llm_sampling_params, rid): + gen_output = res + _gen_t1 = _time.time() + _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 + _gen_t0 = _gen_t1 + await generation_out_q.put((rid, gen_output, _gen_ms)) + except Exception as e: + logger.exception("Failed on request %s: %s", rid, e) + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "error": str(e), + } + ) + + _batch_gen_t0 = _time.time() + while True: + try: + task = in_q.get_nowait() + task_type = task.get("type", OmniStageTaskType.GENERATE) + if task_type == OmniStageTaskType.SHUTDOWN: + logger.debug("Received shutdown signal") + stage_engine.shutdown() + break + elif task_type == OmniStageTaskType.ABORT: + rid = task["request_id"] + asyncio.create_task(stage_engine.abort(rid)) + elif is_profiler_task(task_type): + await handle_profiler_task_async(task_type) + else: + asyncio.create_task(generation_single_request(task)) + + except queue.Empty: + await asyncio.sleep(0.001) + batch_request_outputs: list[Any] = [] + batch_request_ids: list[Any] = [] + _gen_ms_list = [] + batch_metrics: list[Any] = [] + while True: + try: + rid, gen_output, _gen_ms = generation_out_q.get_nowait() + _metrics = make_request_stats( + [gen_output], + _gen_ms, + int(_batch_seq), + 1, # temporarily set to 1 + float(_rx_decode_ms_by_rid.get(rid, 0.0)), + int(_rx_bytes_by_rid.get(rid, 0)), + float(_in_flight_ms_by_rid.get(rid, 0.0)), + ) + batch_metrics.append(_metrics) + batch_request_outputs.append(gen_output) + _gen_ms_list.append(_gen_ms) + batch_request_ids.append(rid) + _agg_total_tokens += _metrics.num_tokens_out + except asyncio.QueueEmpty: + await asyncio.sleep(0.001) + break + + if not batch_request_outputs: + continue + _batch_seq += 1 + + _batch_gen_t1 = _time.time() + _agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0) * 1000 + _batch_gen_t0 = _batch_gen_t1 + for idx, metrics in enumerate(batch_metrics): + metrics.batch_size = len(batch_metrics) + if idx == len(batch_metrics) - 1: + metrics.stage_stats = make_stage_stats(_agg_total_tokens, _agg_total_gen_time_ms) + + logger.debug("Sending outputs to main process") + for rid, output, _gen_ms, _metrics in zip( + batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics + ): + try: + r_outputs = [output_strip(output, omni_stage)] + use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) + if use_shm: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs_shm": payload, + "metrics": _metrics, + } + ) + else: + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs": payload, + "metrics": _metrics, + } + ) + 
logger.debug(f"Enqueued req={rid}, use_shm={use_shm}, tokens_out={_metrics.num_tokens_out}") + except Exception as e: + logger.exception( + "Failed to enqueue result for request %s: %s", + rid, + e, + ) + out_q.put( + { + "request_id": rid, + "stage_id": stage_id, + "engine_outputs": r_outputs, + "metrics": _metrics, + } + ) + logger.debug("Enqueued result for request %s to downstream", rid) + if log_stats_task is not None: + log_stats_task.cancel() + logger.info("Stage worker exiting") + + +def count_prompt_tokens_from_outputs(engine_outputs: list[Any]) -> int: + """Count prompt tokens from engine outputs.""" + total = 0 + for _ro in engine_outputs: + try: + prompt_token_ids = getattr(_ro, "prompt_token_ids", None) + if prompt_token_ids is not None: + total += len(prompt_token_ids) + except Exception: + pass + return total + + +def make_request_stats( + req_output: list[Any], + stage_gen_time_ms: float, + batch_id: int, + batch_size: int, + rx_decode_time_ms: float, + rx_transfer_bytes: int, + rx_in_flight_time_ms: float, +): + from vllm_omni.entrypoints.log_utils import ( + StageRequestMetrics, + ) + + num_tokens_in = count_prompt_tokens_from_outputs(req_output) + num_tokens_out = count_tokens_from_outputs(req_output) + return StageRequestMetrics( + num_tokens_in=num_tokens_in, + num_tokens_out=num_tokens_out, + stage_gen_time_ms=stage_gen_time_ms, + batch_id=batch_id, + batch_size=batch_size, + rx_decode_time_ms=rx_decode_time_ms, + rx_transfer_bytes=rx_transfer_bytes, + rx_in_flight_time_ms=rx_in_flight_time_ms, + stage_stats=None, + ) + + +def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): + from vllm_omni.entrypoints.log_utils import StageStats + + return StageStats(total_token=_agg_total_tokens, total_gen_time=_agg_total_gen_time_ms) + + +def output_strip(r_output: RequestOutput | OmniRequestOutput, omni_stage: OmniStage): + """ + Strip unnecessary multimodal outputs from stages results, + in order to: + - reduce memory usage + - reduce transfer & serialization overhead + """ + + # check multimodal data is required by stage output config. + if omni_stage.final_output and omni_stage.final_output_type != "text": + return r_output + + # If the request has already finished, should not be altered. + if getattr(r_output, "finished", False): + return r_output + + mm_output = getattr(r_output, "multimodal_output", None) + if mm_output is not None: + r_output.multimodal_output = {} + + outputs = getattr(r_output, "outputs", None) + if outputs is not None: + for out in outputs: + if getattr(out, "multimodal_output", None): + out.multimodal_output = {} + + return r_output diff --git a/vllm_omni/entrypoints/openai/__init__.py b/vllm_omni/entrypoints/openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e27cb238c5446cc50d5a3782691ba7d84a70be31 --- /dev/null +++ b/vllm_omni/entrypoints/openai/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +OpenAI-compatible API entrypoints for vLLM-Omni. 
+ +Provides: +- omni_run_server: Main server entry point (auto-detects model type) +- OmniOpenAIServingChat: Unified chat completion handler for both LLM and diffusion models +""" + +from vllm_omni.entrypoints.openai.api_server import ( + build_async_omni, + omni_init_app_state, + omni_run_server, +) +from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + +__all__ = [ + # Server functions + "omni_run_server", + "build_async_omni", + "omni_init_app_state", + # Serving classes + "OmniOpenAIServingChat", +] diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..fb52c7e464d18860815982abc121727d732dd019 --- /dev/null +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -0,0 +1,1453 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import io +import json +import multiprocessing +import multiprocessing.forkserver as forkserver +import os + +# Image generation API imports +import random +import time +import uuid +from argparse import Namespace +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from http import HTTPStatus +from typing import Annotated, Any, cast + +import httpx +import vllm.envs as envs +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile +from fastapi.responses import JSONResponse, StreamingResponse +from PIL import Image +from starlette.datastructures import State +from starlette.routing import Route +from vllm import SamplingParams +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.serving import AnthropicServingMessages +from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer +from vllm.entrypoints.openai.api_server import base, load_log_config +from vllm.entrypoints.openai.api_server import build_app as build_openai_app +from vllm.entrypoints.openai.api_server import setup_server as setup_openai_server +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, +) + +# yapf conflicts with isort for this block +# yapf: disable +# yapf: enable +from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion +from vllm.entrypoints.openai.engine.protocol import ( + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, +) +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.openai.orca_metrics import metrics_header +from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses +from vllm.entrypoints.openai.translations.serving import ( + OpenAIServingTranscription, + OpenAIServingTranslation, +) +from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.pooling.classify.serving import ServingClassification +from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding +from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling +from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.serve.disagg.serving import ServingTokens +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization 
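+# NOTE (overview, descriptive comment added for clarity): the upstream vLLM serving
+# classes imported above are reused as-is; omni-specific endpoints (chat, speech,
+# image generation/edit) are registered on the local APIRouter defined further down
+# and spliced into the upstream app by first removing the conflicting upstream routes
+# (see _remove_route_from_app and _remove_route_from_router below).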
+from vllm.entrypoints.utils import ( + load_aware_call, + process_lora_modules, + with_cancellation, +) +from vllm.logger import init_logger +from vllm.tasks import POOLING_TASKS +from vllm.tool_parsers import ToolParserManager +from vllm.utils.system_utils import decorate_logs + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.openai.image_api_utils import ( + encode_image_base64, + parse_size, +) +from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.protocol.images import ( + ImageData, + ImageGenerationRequest, + ImageGenerationResponse, +) +from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id + +logger = init_logger(__name__) +router = APIRouter() + + +def _remove_route_from_router( + router: APIRouter, + path: str, + methods: set[str] | None = None, +) -> None: + methods_set = {method.upper() for method in methods} if methods else None + for route in list(router.routes): + if getattr(route, "path", None) != path: + continue + if methods_set is not None: + route_methods = {method.upper() for method in (getattr(route, "methods", None) or set())} + if not (route_methods & methods_set): + continue + router.routes.remove(route) + + +ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format" + + +def _remove_route_from_app(app, path: str, methods: set[str] | None = None): + """Remove a route from the app by path and optionally by methods. + + OMNI: used to override upstream /v1/chat/completions with omni behavior. + """ + routes_to_remove = [] + for route in app.routes: + if isinstance(route, Route) and route.path == path: + if methods is None or (hasattr(route, "methods") and route.methods & methods): + routes_to_remove.append(route) + + for route in routes_to_remove: + app.routes.remove(route) + + +class _DiffusionServingModels: + """Minimal OpenAIServingModels implementation for diffusion-only servers. + + vLLM's /v1/models route expects `app.state.openai_serving_models` to expose + `show_available_models()`. In pure diffusion mode we don't initialize the + full OpenAIServingModels (it depends on LLM-specific processors), so we + provide a lightweight fallback. + """ + + def __init__(self, base_model_paths: list[BaseModelPath]) -> None: + self._base_model_paths = base_model_paths + + async def show_available_models(self) -> ModelList: + return ModelList( + data=[ + ModelCard( + id=base_model.name, + root=base_model.model_path, + permission=[ModelPermission()], + ) + for base_model in self._base_model_paths + ] + ) + + +# Server entry points + + +async def omni_run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server. + + Unified entry point that automatically handles both LLM and Diffusion models + through AsyncOmni, which manages multi-stage pipelines. 
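+ The heavy lifting happens in omni_run_server_worker: it builds the engine
+ client via build_async_omni, constructs the upstream vLLM FastAPI app,
+ overlays the omni-specific routes, and then hands the app to serve_http.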
+ """ + # Suppress Pydantic serialization warnings globally for multimodal content + # (e.g., when ChatMessage.content is a list instead of str) + import warnings as warnings_module + + warnings_module.filterwarnings("ignore", message=".*Pydantic.*serialization.*", category=UserWarning) + warnings_module.filterwarnings("ignore", message=".*PydanticSerializationUnexpectedValue.*", category=UserWarning) + + # Add process-specific prefix to stdout and stderr. + decorate_logs("APIServer") + + listen_address, sock = setup_openai_server(args) + + # Unified use of omni_run_server_worker, AsyncOmni automatically handles LLM and Diffusion models + await omni_run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def omni_run_server_worker(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + from vllm.reasoning import ReasoningParserManager + + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + # Load logging config for uvicorn if specified + log_config = load_log_config(getattr(args, "log_config_file", None)) + if log_config is not None: + uvicorn_kwargs["log_config"] = log_config + + async with build_async_omni( + args, + client_config=client_config, + ) as engine_client: + supported_tasks: tuple[str, ...] + if hasattr(engine_client, "get_supported_tasks"): + supported_tasks = tuple(await engine_client.get_supported_tasks()) + else: + supported_tasks = ("generate",) + if not supported_tasks: + supported_tasks = ("generate",) + + app = build_openai_app(args) + # OMNI: Remove upstream routes that we override with omni-specific handlers + _remove_route_from_app(app, "/v1/chat/completions", {"POST"}) + _remove_route_from_app(app, "/v1/models", {"GET"}) # Remove upstream /v1/models to use omni's handler + app.include_router(router) + + await omni_init_app_state(engine_client, app.state, args) + + vllm_config = await engine_client.get_vllm_config() + + # Check if pure diffusion mode (vllm_config will be None) + is_pure_diffusion = vllm_config is None + if is_pure_diffusion: + logger.info( + "Starting vLLM API server (pure diffusion mode) on %s", + listen_address, + ) + else: + logger.info( + "Starting vLLM API server %d on %s", + vllm_config.parallel_config._api_process_rank, + listen_address, + ) + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. 
+ access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +@asynccontextmanager +async def build_async_omni( + args: Namespace, + *, + disable_frontend_multiprocessing: bool | None = None, + client_config: dict[str, Any] | None = None, +) -> AsyncIterator[EngineClient]: + """Build an AsyncOmni instance from command-line arguments. + + Creates an async context manager that yields an AsyncOmni instance + configured from the provided arguments. Handles forkserver setup if + needed and ensures proper cleanup on exit. + + Args: + args: Parsed command-line arguments containing model and configuration + disable_frontend_multiprocessing: Optional flag to disable frontend + multiprocessing (deprecated in V1) + client_config: Optional client configuration dictionary + + Yields: + EngineClient instance (AsyncOmni) ready for use + """ + if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": + # The executor is expected to be mp. + # Pre-import heavy modules in the forkserver process + logger.debug("Setup forkserver with pre-imports") + multiprocessing.set_start_method("forkserver") + multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"]) + forkserver.ensure_running() + logger.debug("Forkserver setup complete!") + + # Context manager to handle async_omni lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + async with build_async_omni_from_stage_config( + args, + disable_frontend_multiprocessing=disable_frontend_multiprocessing, + ) as async_omni: + yield async_omni + + +@asynccontextmanager +async def build_async_omni_from_stage_config( + args: Namespace, + *, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """Create AsyncOmni from stage configuration. + + Creates an AsyncOmni instance either in-process or using multiprocess + RPC. Loads stage configurations from the model or from a specified path. + + Args: + args: Parsed command-line arguments containing model and stage configs + disable_frontend_multiprocessing: Flag to disable frontend multiprocessing + (deprecated in V1) + client_config: Optional client configuration dictionary + + Yields: + EngineClient instance (AsyncOmni) ready for use + + Note: + Stage configurations are loaded from args.stage_configs_path if provided, + otherwise from the model's default configuration. + """ + + # V1 AsyncLLM. + if disable_frontend_multiprocessing: + logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") + + async_omni: EngineClient | None = None + + try: + # Convert args Namespace to kwargs dict for AsyncOmni to use + kwargs = vars(args).copy() + # Remove model as it will be passed separately + kwargs.pop("model", None) + async_omni = AsyncOmni(model=args.model, **kwargs) + + # # Don't keep the dummy data in memory + # await async_llm.reset_mm_cache() + + yield async_omni + finally: + if async_omni: + async_omni.shutdown() + + +async def omni_init_app_state( + engine_client: EngineClient, + state: State, + args: Namespace, +) -> None: + """Initialize the FastAPI application state for omni API server. 
+ + Sets up the application state with model information, request logger, + and other server configuration needed for handling API requests. + Automatically detects pure diffusion mode (single diffusion stage) and + handles it appropriately. + + Args: + engine_client: Engine client instance (AsyncOmni) + state: FastAPI application state object to initialize + args: Parsed command-line arguments + """ + # Get vllm_config from engine_client (following 0.14.0 pattern) + vllm_config = await engine_client.get_vllm_config() + + # Detect if it's pure Diffusion mode (single stage and is Diffusion) + is_pure_diffusion = False + if hasattr(engine_client, "stage_configs") and engine_client.stage_configs: + stage_configs = engine_client.stage_configs + if len(stage_configs) == 1: + stage_type = stage_configs[0].get("stage_type", "llm") + if stage_type == "diffusion": + is_pure_diffusion = True + logger.info("Detected pure diffusion mode (single diffusion stage)") + + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.enable_log_requests: + request_logger = RequestLogger(max_log_len=args.max_log_len) + else: + request_logger = None + + base_model_paths = [BaseModelPath(name=name, model_path=args.model) for name in served_model_names] + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + state.args = args + + # For omni models + state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None + + # Pure Diffusion mode: use simplified initialization logic + if is_pure_diffusion: + model_name = served_model_names[0] if served_model_names else args.model + state.vllm_config = None + state.diffusion_engine = engine_client + state.openai_serving_models = _DiffusionServingModels(base_model_paths) + # OMNI: tokenization endpoints are not supported in pure diffusion mode. 
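+ # Descriptive note: route handlers that need this object (e.g. the fallback path
+ # in create_chat_completion below) check for None and return an error response
+ # rather than assuming tokenization support exists.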
+ state.openai_serving_tokenization = None + + # Use for_diffusion method to create chat handler + state.openai_serving_chat = OmniOpenAIServingChat.for_diffusion( + diffusion_engine=engine_client, # type: ignore + model_name=model_name, + ) + + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) + state.server_load_metrics = 0 + logger.info("Pure diffusion API server initialized for model: %s", model_name) + return + + # LLM or multi-stage mode: use standard initialization logic + if vllm_config is None: + # Try to get vllm_config from engine_client + vllm_config = await engine_client.get_vllm_config() + if vllm_config is None: + logger.warning("vllm_config is None, some features may not work correctly") + + state.vllm_config = vllm_config + + # Get supported tasks + supported_tasks: set[str] = {"generate"} + if hasattr(engine_client, "get_supported_tasks"): + supported_tasks = set(await engine_client.get_supported_tasks()) + logger.info("Supported tasks: %s", supported_tasks) + + resolved_chat_template = load_chat_template(args.chat_template) + + if args.tool_server == "demo": + tool_server: ToolServer | None = DemoToolServer() + assert isinstance(tool_server, DemoToolServer) + await tool_server.init_and_validate() + elif args.tool_server: + tool_server = MCPToolServer() + await tool_server.add_tool_server(args.tool_server) + else: + tool_server = None + + # Merge default_mm_loras into the static lora_modules + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config is not None and vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) + + # Ensure input_processor, io_processor, and model_config exist for OpenAIServingModels compatibility + if ( + not hasattr(engine_client, "input_processor") + or engine_client.input_processor is None + or not hasattr(engine_client, "io_processor") + or engine_client.io_processor is None + or not hasattr(engine_client, "model_config") + or engine_client.model_config is None + ): + if vllm_config is not None: + # Try to initialize processors if vllm_config is available + try: + from vllm.plugins.io_processors import get_io_processor + + from vllm_omni.engine.input_processor import OmniInputProcessor + + tokenizer = await engine_client.get_tokenizer() + if tokenizer is not None: + # Initialize input_processor + # OMNI: OmniInputProcessor creates tokenizer internally from vllm_config + if not hasattr(engine_client, "input_processor") or engine_client.input_processor is None: + engine_client.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + ) + logger.info("Initialized input_processor for AsyncOmni") + + # Initialize model_config + if not hasattr(engine_client, "model_config") or engine_client.model_config is None: + engine_client.model_config = vllm_config.model_config + logger.info("Initialized model_config for AsyncOmni") + + # Initialize io_processor + if not hasattr(engine_client, "io_processor") or engine_client.io_processor is None: + model_config = ( + engine_client.model_config + if hasattr(engine_client, "model_config") + else vllm_config.model_config + ) + io_processor_plugin = model_config.io_processor_plugin + engine_client.io_processor = get_io_processor(vllm_config, io_processor_plugin) + logger.info("Initialized io_processor for AsyncOmni") + else: + logger.warning("Cannot initialize processors: tokenizer is None. 
OpenAIServingModels may fail.") + except Exception as e: + logger.warning( + "Failed to initialize processors for AsyncOmni: %s. OpenAIServingModels may fail.", + e, + ) + else: + logger.warning("Cannot initialize processors: vllm_config is None. OpenAIServingModels may fail.") + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + ) + await state.openai_serving_models.init_static_loras() + + state.openai_serving_responses = ( + OpenAIServingResponses( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + tool_server=tool_server, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_chat = ( + OmniOpenAIServingChat( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, + trust_request_chat_template=args.trust_request_chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + enable_log_deltas=args.enable_log_deltas, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + # Warm up chat template processing to avoid first-request latency + if state.openai_serving_chat is not None: + await state.openai_serving_chat.warmup() + + state.openai_serving_completion = ( + OpenAIServingCompletion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_pooling = ( + OpenAIServingPooling( + engine_client, + state.openai_serving_models, + supported_tasks=supported_tasks, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if any(task in POOLING_TASKS for task in supported_tasks) + else None + ) + state.openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + 
chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "embed" in supported_tasks + else None + ) + state.openai_serving_classification = ( + ServingClassification( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "classify" in supported_tasks + else None + ) + state.openai_serving_scores = ( + ServingScores( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + score_template=resolved_chat_template, + log_error_stack=args.log_error_stack, + ) + if ("embed" in supported_tasks or "score" in supported_tasks) + else None + ) + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + state.openai_serving_transcription = ( + OpenAIServingTranscription( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) + state.openai_serving_translation = ( + OpenAIServingTranslation( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) + state.anthropic_serving_messages = ( + AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "generate" in supported_tasks + else None + ) + state.serving_tokens = ( + ServingTokens( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + log_error_stack=args.log_error_stack, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_log_outputs=args.enable_log_outputs, + force_no_detokenize=args.tokens_only, + ) + if "generate" in supported_tasks + else None + ) + + state.openai_serving_speech = OmniOpenAIServingSpeech( + engine_client, state.openai_serving_models, request_logger=request_logger + ) + + state.enable_server_load_tracking = args.enable_server_load_tracking + state.server_load_metrics = 0 + + +def Omnichat(request: Request) -> OmniOpenAIServingChat | None: + return request.app.state.openai_serving_chat + + +def Omnispeech(request: Request) -> OmniOpenAIServingSpeech | None: + return request.app.state.openai_serving_speech + + +@router.post( + "/v1/chat/completions", + 
dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + metrics_header_format = raw_request.headers.get(ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "") + handler = Omnichat(raw_request) + if handler is None: + base_server = getattr(raw_request.app.state, "openai_serving_tokenization", None) + if base_server is None: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND.value, + detail="The model does not support Chat Completions API", + ) + return base_server.create_error_response(message="The model does not support Chat Completions API") + try: + generator = await handler.create_chat_completion(request, raw_request) + except Exception as e: + logger.exception("Chat completion failed: %s", e) + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), + status_code=generator.error.code if generator.error else 400, + ) + + elif isinstance(generator, ChatCompletionResponse): + # Completely bypass Pydantic serialization warnings for multimodal content + # by converting to dict first, then serializing with warnings suppressed + import json as json_lib + import warnings as warnings_module + + # Temporarily suppress ALL Pydantic UserWarnings during serialization + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning) + warnings_module.filterwarnings("ignore", message=".*Pydantic.*", category=UserWarning) + try: + # Use serialize_as_any=True to bypass type checking + response_dict = generator.model_dump(mode="json", serialize_as_any=True, warnings="none") + return JSONResponse( + content=response_dict, + headers=metrics_header(metrics_header_format), + ) + except Exception: + # Fallback: convert to JSON string and parse back to avoid any serialization issues + try: + response_json = generator.model_dump_json(warnings="none", serialize_as_any=True) + response_dict = json_lib.loads(response_json) + return JSONResponse( + content=response_dict, + headers=metrics_header(metrics_header_format), + ) + except Exception: + # Last resort: regular dump with warnings suppressed + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning) + return JSONResponse( + content=generator.model_dump(mode="json", warnings="none"), + headers=metrics_header(metrics_header_format), + ) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +_remove_route_from_router(router, "/v1/audio/speech", {"POST"}) + + +@router.post( + "/v1/audio/speech", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"audio/*": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request): + handler = Omnispeech(raw_request) + if handler is None: + base_server = getattr(raw_request.app.state, 
"openai_serving_tokenization", None) + if base_server is None: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND.value, + detail="The model does not support Speech API", + ) + return base_server.create_error_response(message="The model does not support Speech API") + try: + return await handler.create_speech(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e + + +@router.get( + "/v1/audio/voices", + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def list_voices(raw_request: Request): + """List available TTS voices/speakers from the loaded model.""" + handler = Omnispeech(raw_request) + if handler is None: + return base(raw_request).create_error_response(message="The model does not support Speech API") + + speakers = sorted(handler.supported_speakers) if handler.supported_speakers else [] + return JSONResponse(content={"voices": speakers}) + + +# Health and Model endpoints for diffusion mode + + +# Remove existing health endpoint if present (from vllm imports) +# to ensure our handler takes precedence +_remove_route_from_router(router, "/health") + + +@router.get("/health") +async def health(raw_request: Request) -> JSONResponse: + """Health check endpoint that works for both LLM and diffusion modes. + + Returns 200 OK if the server is healthy. + For LLM mode: delegates to engine_client health check + For diffusion mode: checks if diffusion_engine is running + """ + # Check if we're in diffusion mode + diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) + if diffusion_engine is not None: + # Diffusion mode health check + if hasattr(diffusion_engine, "is_running") and diffusion_engine.is_running: + return JSONResponse(content={"status": "healthy"}) + return JSONResponse( + content={"status": "unhealthy", "reason": "Diffusion engine is not running"}, + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + ) + + # LLM mode - delegate to engine_client + engine_client = getattr(raw_request.app.state, "engine_client", None) + if engine_client is not None: + await engine_client.check_health() + return JSONResponse(content={"status": "healthy"}) + + return JSONResponse( + content={"status": "unhealthy", "reason": "No engine initialized"}, + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + ) + + +# Remove existing models endpoint if present (from vllm imports) +# to ensure our handler takes precedence +_remove_route_from_router(router, "/v1/models") + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request) -> JSONResponse: + """Show available models endpoint that works for both LLM and diffusion modes. + + Returns model information in OpenAI-compatible format. 
+ """ + # Check if we're in diffusion mode + diffusion_model_name = getattr(raw_request.app.state, "diffusion_model_name", None) + if diffusion_model_name is not None: + # Diffusion mode - return the loaded model + return JSONResponse( + content={ + "object": "list", + "data": [ + { + "id": diffusion_model_name, + "object": "model", + "created": 0, + "owned_by": "vllm-omni", + "permission": [], + } + ], + } + ) + + # LLM mode - delegate to openai_serving_models + openai_serving_models = getattr(raw_request.app.state, "openai_serving_models", None) + if openai_serving_models is not None: + models = await openai_serving_models.show_available_models() + return JSONResponse(content=models.model_dump()) + + return JSONResponse( + content={"object": "list", "data": []}, + ) + + +# Image generation API endpoints + + +@router.post( + "/v1/images/generations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": ImageGenerationResponse}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.SERVICE_UNAVAILABLE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def generate_images(request: ImageGenerationRequest, raw_request: Request) -> ImageGenerationResponse: + """Generate images from text prompts using diffusion models. + + OpenAI DALL-E compatible endpoint for text-to-image generation. + Only supports multi-stage omni mode with diffusion stages. + + Args: + request: Image generation request with prompt and parameters + raw_request: Raw FastAPI request for accessing app state + + Returns: + ImageGenerationResponse with generated images as base64 PNG + + Raises: + HTTPException: For validation errors, missing engine, or generation failures + """ + # Get engine client (AsyncOmni) from app state + engine_client, model_name, stage_types = _get_engine_and_model(raw_request) + + # Validate model field (warn if mismatch, don't error) + if request.model is not None and request.model != model_name: + logger.warning( + f"Model mismatch: request specifies '{request.model}' but " + f"server is running '{model_name}'. Using server model." + ) + + try: + # Build params - pass through user values directly + prompt: OmniTextPrompt = {"prompt": request.prompt} + if request.negative_prompt is not None: + prompt["negative_prompt"] = request.negative_prompt + gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n) + + # Parse per-request LoRA (compatible with chat's extra_body.lora shape). + lora_request, lora_scale = _parse_lora_request(request.lora) + _update_if_not_none(gen_params, "lora_request", lora_request) + _update_if_not_none(gen_params, "lora_scale", lora_scale) + + # Parse and add size if provided + width, height = None, None + if request.size: + width, height = parse_size(request.size) + size_str = f"{width}x{height}" + else: + size_str = "model default" + _update_if_not_none(gen_params, "width", width) + _update_if_not_none(gen_params, "height", height) + + # 3.3 Add optional parameters ONLY if provided + _update_if_not_none(gen_params, "num_inference_steps", request.num_inference_steps) + _update_if_not_none(gen_params, "guidance_scale", request.guidance_scale) + _update_if_not_none(gen_params, "true_cfg_scale", request.true_cfg_scale) + # If seed is not provided, generate a random one to ensure + # a proper generator is initialized in the backend. 
+ # This fixes issues where using the default global generator + # might produce blurry images in some environments. + _update_if_not_none(gen_params, "seed", random.randint(0, 2**32 - 1) if request.seed is None else request.seed) + + request_id = f"img_gen_{uuid.uuid4().hex}" + + logger.info(f"Generating {request.n} image(s) {size_str}") + + # Generate images using AsyncOmni (multi-stage mode) + result = await _generate_with_async_omni( + engine_client=engine_client, + gen_params=gen_params, + stage_types=stage_types, + prompt=prompt, + request_id=request_id, + ) + + if result is None: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail="No output generated from multi-stage pipeline.", + ) + + # Extract images from result + images = _extract_images_from_result(result) + + logger.info(f"Successfully generated {len(images)} image(s)") + + # Encode images to base64 + image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in images] + + return ImageGenerationResponse( + created=int(time.time()), + data=image_data, + ) + + except HTTPException: + # Re-raise HTTPExceptions as-is + raise + except ValueError as e: + logger.error(f"Validation error: {e}") + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) + except Exception as e: + logger.exception(f"Image generation failed: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Image generation failed: {str(e)}" + ) + + +@router.post( + "/v1/images/edits", + responses={ + HTTPStatus.OK.value: {"model": ImageGenerationResponse}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.SERVICE_UNAVAILABLE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def edit_images( + raw_request: Request, + image: list[UploadFile] | None = File(None), + image_array: list[UploadFile] | None = File(None, alias="image[]"), + url: list[str] | None = Form(None), + url_array: list[str] | None = Form(None, alias="url[]"), + prompt: str = Form(...), + model: str = Form(None), + n: int = Form(1), + size: str = Form("auto"), + response_format: str = Form("b64_json"), + output_format: str | None = Form("png"), + background: str | None = Form("auto"), + output_compression: Annotated[int, Form(ge=0, le=100)] = 100, + user: str | None = Form(None), # unused now + # vllm-omni extensions for diffusion control + negative_prompt: str | None = Form(None), + num_inference_steps: int | None = Form(None), + guidance_scale: float | None = Form(None), + true_cfg_scale: float | None = Form(None), + seed: int | None = Form(None), + # vllm-omni extension for per-request LoRA. + lora: str | None = Form(None), # Json string +) -> ImageGenerationResponse: + """ + OpenAI-compatible image edit endpoint. + """ + # 1. get engine and model + engine_client, model_name, stage_types = _get_engine_and_model(raw_request) + if model is not None and model != model_name: + logger.warning( + f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model." + ) + # 2. get output format & compression + output_format = _choose_output_format(output_format, background) + if response_format != "b64_json": + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Only response_format 'b64_json' is supported now.", + ) + try: + # 2. 
Build prompt & images params
+        prompt: OmniTextPrompt = {"prompt": prompt}
+        if negative_prompt is not None:
+            prompt["negative_prompt"] = negative_prompt
+        input_images_list = []
+        images = image or image_array
+        urls = url or url_array
+        if images:
+            input_images_list.extend(images)
+        if urls:
+            input_images_list.extend(urls)
+        if not input_images_list:
+            raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required")
+        pil_images = await _load_input_images(input_images_list)
+        prompt["multi_modal_data"] = {}
+        prompt["multi_modal_data"]["image"] = pil_images
+
+        # 3. Build sample params
+        gen_params = OmniDiffusionSamplingParams()
+        # 3.0 Init with system default values
+        app_state_args = getattr(raw_request.app.state, "args", None)
+        default_sample_param = getattr(app_state_args, "default_sampling_params", None)
+        # Currently only have one diffusion stage
+        diffusion_stage_id = [i for i, t in enumerate(stage_types) if t == "diffusion"][0]
+        apply_stage_default_sampling_params(
+            default_sample_param,
+            gen_params,
+            str(diffusion_stage_id),
+        )
+        _update_if_not_none(gen_params, "num_outputs_per_prompt", n)
+        # 3.1 Parse per-request LoRA (compatible with chat's extra_body.lora shape).
+        lora_dict = _get_lora_from_json_str(lora)
+        lora_request, lora_scale = _parse_lora_request(lora_dict)
+        _update_if_not_none(gen_params, "lora_request", lora_request)
+        _update_if_not_none(gen_params, "lora_scale", lora_scale)
+        # 3.2 Parse and add size if provided
+        max_generated_image_size = getattr(app_state_args, "max_generated_image_size", None)
+        width, height = None, None
+        if size.lower() == "auto":
+            width, height = pil_images[0].size  # Use first image size
+        else:
+            width, height = parse_size(size)
+        if max_generated_image_size is not None and (width * height > max_generated_image_size):
+            raise HTTPException(
+                status_code=HTTPStatus.BAD_REQUEST.value,
+                detail=f"Requested image size {width}x{height} exceeds the maximum allowed "
+                f"size of {max_generated_image_size} pixels.",
+            )
+
+        size_str = f"{width}x{height}"
+        _update_if_not_none(gen_params, "width", width)
+        _update_if_not_none(gen_params, "height", height)
+
+        # 3.3 Add optional parameters ONLY if provided
+        _update_if_not_none(gen_params, "num_inference_steps", num_inference_steps)
+        _update_if_not_none(gen_params, "guidance_scale", guidance_scale)
+        _update_if_not_none(gen_params, "true_cfg_scale", true_cfg_scale)
+        # If seed is not provided, generate a random one to ensure
+        # a proper generator is initialized in the backend.
+        # This fixes issues where using the default global generator
+        # might produce blurry images in some environments.
+        # Compare against None explicitly so an explicit seed of 0 is still honored.
+        _update_if_not_none(gen_params, "seed", random.randint(0, 2**32 - 1) if seed is None else seed)
+
+        # 4. Generate images using AsyncOmni (multi-stage mode)
+        request_id = f"img_edit_{uuid.uuid4().hex}"
+        logger.info(f"Generating {n} image(s) {size_str}")
+        result = await _generate_with_async_omni(
+            engine_client=engine_client,
+            gen_params=gen_params,
+            stage_types=stage_types,
+            prompt=prompt,
+            request_id=request_id,
+        )
+
+        # 5.
Extract images from result + images = _extract_images_from_result(result) + logger.info(f"Successfully generated {len(images)} image(s)") + + # Encode images to base64 + image_data = [ + ImageData( + b64_json=_encode_image_base64_with_compression( + img, format=output_format, output_compression=output_compression + ), + revised_prompt=None, + ) + for img in images + ] + + return ImageGenerationResponse( + created=int(time.time()), + data=image_data, + output_format=output_format, + size=size_str, + ) + + except HTTPException: + # Re-raise HTTPExceptions as-is + raise + except ValueError as e: + logger.error(f"Validation error: {e}") + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) + except Exception as e: + logger.exception(f"Image edit failed: {e}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Image edit failed: {str(e)}") + + +def _get_engine_and_model(raw_request: Request): + # Get engine client (AsyncOmni) from app state + engine_client: EngineClient | AsyncOmni | None = getattr(raw_request.app.state, "engine_client", None) + if engine_client is None or not hasattr(engine_client, "stage_list"): + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="Multi-stage engine not initialized. Start server with a multi-stage omni model.", + ) + + # Check if there's a diffusion stage + stage_configs = getattr(raw_request.app.state, "stage_configs", None) + if not stage_configs: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="Stage configs not found. Start server with a multi-stage omni model.", + ) + + # Check for diffusion stage and collect stage types + has_diffusion_stage = False + stage_types: list[str] = [] + for stage in stage_configs: + # Handle both dict and OmegaConf objects + stage_type = None + if isinstance(stage, dict): + stage_type = stage.get("stage_type", "llm") + elif hasattr(stage, "get"): + stage_type = stage.get("stage_type", "llm") + elif hasattr(stage, "stage_type"): + stage_type = stage.stage_type + else: + # Fallback: try to access as dict-like + try: + stage_type = stage["stage_type"] if "stage_type" in stage else "llm" + except (TypeError, KeyError): + stage_type = "llm" + + if stage_type == "diffusion": + has_diffusion_stage = True + stage_types.append(stage_type) + + if not has_diffusion_stage: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="No diffusion stage found in multi-stage pipeline.", + ) + + # Get server's loaded model name + serving_models = getattr(raw_request.app.state, "openai_serving_models", None) + if serving_models and hasattr(serving_models, "base_model_paths") and serving_models.base_model_paths: + model_name = serving_models.base_model_paths[0].name + else: + model_name = "unknown" + + return engine_client, model_name, stage_types + + +def _get_lora_from_json_str(lora_body): + if lora_body is None: + return None + try: + lora_dict = json.loads(lora_body) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="Invalid LoRA JSON string") + + if not isinstance(lora_dict, dict): + raise HTTPException(status_code=400, detail="LoRA must be a JSON object") + + return lora_dict + + +def _parse_lora_request(lora_body: dict[str, Any]): + if lora_body is not None: + if not isinstance(lora_body, dict): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Invalid lora field: expected an object.", + ) + lora_name = lora_body.get("name") or 
lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + + if not lora_name or not lora_path: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Invalid lora object: both name and path are required.", + ) + + return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), lora_scale + return None, None + + +async def _generate_with_async_omni( + engine_client: AsyncOmni | Any, + gen_params: Any, + stage_types: list[str], + **kwargs, +): + engine_client = cast(AsyncOmni, engine_client) + result = None + stage_list = getattr(engine_client, "stage_list", None) + if isinstance(stage_list, list): + default_params_list: list[OmniSamplingParams] | None = getattr( + engine_client, "default_sampling_params_list", None + ) + if not isinstance(default_params_list, list): + default_params_list = [ + OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types + ] + else: + default_params_list = list(default_params_list) + if len(default_params_list) != len(stage_types): + default_params_list = ( + default_params_list + + [OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types] + )[: len(stage_types)] + + sampling_params_list: list[OmniSamplingParams] = [] + for idx, stage_type in enumerate(stage_types): + if stage_type == "diffusion": + sampling_params_list.append(gen_params) + else: + base_params = default_params_list[idx] + sampling_params_list.append(base_params) + + async for output in engine_client.generate( + sampling_params_list=sampling_params_list, + **kwargs, + ): + result = output + else: + result = await engine_client.generate( + sampling_params_list=[gen_params], + **kwargs, + ) + + if result is None: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail="No output generated from multi-stage pipeline.", + ) + return result + + +def _update_if_not_none(object: any, key: str, val: any) -> None: + if val is not None: + setattr(object, key, val) + + +def _extract_images_from_result(result: Any) -> list[Any]: + images = [] + if hasattr(result, "images") and result.images: + images = result.images + elif hasattr(result, "request_output"): + request_output = result.request_output + if isinstance(request_output, dict) and request_output.get("images"): + images = request_output["images"] + elif hasattr(request_output, "images") and request_output.images: + images = request_output.images + return images + + +async def _load_input_images( + inputs: list[str], +) -> list[Image.Image]: + """ + convert to PIL.Image.Image list + """ + if isinstance(inputs, str): + inputs = [inputs] + + images: list[Image.Image] = [] + + for inp in inputs: + # 1. URL + base64 + if isinstance(inp, str) and inp.startswith("data:image"): + try: + _, b64_data = inp.split(",", 1) + image_bytes = base64.b64decode(b64_data) + img = Image.open(io.BytesIO(image_bytes)) + images.append(img) + except Exception as e: + raise ValueError(f"Invalid base64 image: {e}") + + # 2. 
URL
+        elif isinstance(inp, str) and inp.startswith("http"):
+            async with httpx.AsyncClient(timeout=60) as client:
+                try:
+                    resp = await client.get(inp)
+                    resp.raise_for_status()
+                    img = Image.open(io.BytesIO(resp.content))
+                    images.append(img)
+                except Exception as e:
+                    raise ValueError(f"Failed to download image from URL {inp}: {e}")
+
+        # 3. UploadFile
+        elif hasattr(inp, "file"):
+            try:
+                img_data = await inp.read()
+                img = Image.open(io.BytesIO(img_data))
+                images.append(img)
+            except Exception as e:
+                raise ValueError(f"Failed to open uploaded file: {e}")
+        else:
+            raise ValueError(f"Unsupported input: {inp}")
+
+    if not images:
+        raise ValueError("No valid input images found")
+
+    return images
+
+
+def _choose_output_format(output_format: str | None, background: str | None) -> str:
+    # Normalize and choose extension
+    fmt = (output_format or "").lower()
+    # PIL's save() recognizes "jpeg" but not "jpg", so normalize it here
+    if fmt == "jpg":
+        fmt = "jpeg"
+    if fmt in {"png", "webp", "jpeg"}:
+        return fmt
+    # If transparency requested, prefer png
+    if (background or "auto").lower() == "transparent":
+        return "png"
+    # Default
+    return "jpeg"
+
+
+def _encode_image_base64_with_compression(
+    image: Image.Image, format: str = "png", output_compression: int = 100
+) -> str:
+    """Encode a PIL Image to a base64 string in the requested format.
+
+    Args:
+        image: PIL Image object
+        format: Output image format (e.g., "png", "jpeg", "webp")
+        output_compression: Compression level (0-100%), 100 for best quality
+    Returns:
+        Base64-encoded image as string
+    """
+    buffer = io.BytesIO()
+    save_kwargs = {}
+    if format in ("jpg", "jpeg", "webp"):
+        save_kwargs["quality"] = output_compression
+    elif format == "png":
+        save_kwargs["compress_level"] = max(0, min(9, 9 - output_compression // 11))  # Map 0-100 to 9-0
+
+    image.save(buffer, format=format, **save_kwargs)
+    buffer.seek(0)
+    return base64.b64encode(buffer.read()).decode("utf-8")
+
+
+def apply_stage_default_sampling_params(
+    default_params_json: str | None,
+    sampling_params: Any,
+    stage_key: str,
+) -> None:
+    """
+    Update a stage's sampling parameters with vLLM-Omni defaults.
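+
+    As an illustrative sketch (the stage key and parameter names below are
+    just examples, not a fixed schema), default_params_json is a JSON object
+    keyed by stage id, e.g. '{"1": {"num_inference_steps": 30, "guidance_scale": 4.0}}'.
+    Only parameter names that already exist as attributes on sampling_params
+    are applied.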
+ + Args: + default_params_json: JSON string of stage-keyed default parameters + sampling_params: The sampling parameters object to update + stage_key: The stage ID/key in the pipeline + """ + if default_params_json is not None: + default_params_dict = json.loads(default_params_json) + if stage_key in default_params_dict: + stage_defaults = default_params_dict[stage_key] + for param_name, param_value in stage_defaults.items(): + if hasattr(sampling_params, param_name): + setattr(sampling_params, param_name, param_value) diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..13df32ebe00dfc4f63acf0ae47135fe090d0db7e --- /dev/null +++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py @@ -0,0 +1,93 @@ +from io import BytesIO + +import numpy as np +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio + +try: + import soundfile +except ImportError: + soundfile = None + +try: + import librosa +except ImportError: + librosa = None + +logger = init_logger(__name__) + + +class AudioMixin: + """Mixin class to add audio-related utilities.""" + + def create_audio(self, audio_obj: CreateAudio) -> AudioResponse: + """Convert audio tensor to bytes in the specified format.""" + + audio_tensor = audio_obj.audio_tensor + sample_rate = audio_obj.sample_rate + response_format = audio_obj.response_format.lower() + stream_format = audio_obj.stream_format + base64_encode = audio_obj.base64_encode + speed = audio_obj.speed + + if stream_format != "audio": + raise ValueError(f"Unsupported stream format: {stream_format}") + + if soundfile is None: + raise ImportError( + "soundfile is required for audio generation. Please install it with: pip install soundfile" + ) + + if audio_tensor.ndim > 2: + raise ValueError( + f"Unsupported audio tensor dimension: {audio_tensor.ndim}. " + "Only mono (1D) and stereo (2D) are supported." + ) + + audio_tensor, sample_rate = self._apply_speed_adjustment(audio_tensor, speed, sample_rate) + + supported_formats = { + "wav": ("WAV", "audio/wav", {}), + "pcm": ("RAW", "audio/pcm", {"subtype": "PCM_16"}), + "flac": ("FLAC", "audio/flac", {}), + "mp3": ("MP3", "audio/mpeg", {}), + "aac": ("AAC", "audio/aac", {}), + "opus": ("OGG", "audio/ogg", {"subtype": "OPUS"}), + } + + if response_format not in supported_formats: + logger.warning(f"Unsupported response format '{response_format}', defaulting to 'wav'.") + response_format = "wav" + + soundfile_format, media_type, kwargs = supported_formats[response_format] + + with BytesIO() as buffer: + soundfile.write(buffer, audio_tensor, sample_rate, format=soundfile_format, **kwargs) + audio_data = buffer.getvalue() + + if base64_encode: + import base64 + + audio_data = base64.b64encode(audio_data).decode("utf-8") + + return AudioResponse(audio_data=audio_data, media_type=media_type) + + def _apply_speed_adjustment(self, audio_tensor: np.ndarray, speed: float, sample_rate: int): + """Apply speed adjustment to the audio tensor while preserving pitch.""" + if speed == 1.0: + return audio_tensor, sample_rate + + if librosa is None: + raise ImportError("librosa is required for speed adjustment. Please install it with: pip install librosa") + + try: + # librosa.effects.time_stretch requires a float audio tensor. 
+ if not np.issubdtype(audio_tensor.dtype, np.floating): + audio_tensor = audio_tensor.astype(np.float32) + + stretched_audio = librosa.effects.time_stretch(y=audio_tensor, rate=speed) + return stretched_audio, sample_rate + except Exception as e: + logger.error(f"An error occurred during speed adjustment: {e}") + raise ValueError("Failed to apply speed adjustment.") from e diff --git a/vllm_omni/entrypoints/openai/image_api_utils.py b/vllm_omni/entrypoints/openai/image_api_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9d8fa52464de6788e3ed410326aea739e5c486 --- /dev/null +++ b/vllm_omni/entrypoints/openai/image_api_utils.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared helper utilities for OpenAI-compatible image generation API. + +This module provides common helper functions for the image generation endpoint. +All functions work with plain Python types to maintain separation from the +FastAPI HTTP layer. +""" + +import base64 +import io + +import PIL.Image + + +def parse_size(size_str: str) -> tuple[int, int]: + """Parse size string to width and height tuple. + + Args: + size_str: Size in format "WIDTHxHEIGHT" (e.g., "1024x1024") + + Returns: + Tuple of (width, height) + + Raises: + ValueError: If size format is invalid + """ + if not size_str or not isinstance(size_str, str): + raise ValueError( + f"Size must be a non-empty string in format 'WIDTHxHEIGHT' (e.g., '1024x1024'), got: {size_str}" + ) + + parts = size_str.split("x") + if len(parts) != 2: + raise ValueError( + f"Invalid size format: '{size_str}'. Expected format: 'WIDTHxHEIGHT' (e.g., '1024x1024'). " + f"Did you mean to use 'x' as separator?" + ) + + try: + width = int(parts[0]) + height = int(parts[1]) + except ValueError: + raise ValueError(f"Invalid size format: '{size_str}'. Width and height must be integers.") + + if width <= 0 or height <= 0: + raise ValueError(f"Invalid size: {width}x{height}. Width and height must be positive integers.") + + return width, height + + +def encode_image_base64(image: PIL.Image.Image) -> str: + """Encode PIL Image to base64 PNG string. 
+ + Args: + image: PIL Image object + + Returns: + Base64-encoded PNG image as string + """ + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + return base64.b64encode(buffer.read()).decode("utf-8") diff --git a/vllm_omni/entrypoints/openai/protocol/__init__.py b/vllm_omni/entrypoints/openai/protocol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da65e1817fe267e7d80e655071b07e2dc66c0d46 --- /dev/null +++ b/vllm_omni/entrypoints/openai/protocol/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionStreamResponse +from vllm_omni.entrypoints.openai.protocol.images import ( + ImageData, + ImageGenerationRequest, + ImageGenerationResponse, + ResponseFormat, +) + +__all__ = [ + "ImageData", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ResponseFormat", + "OmniChatCompletionStreamResponse", +] diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..d23460626b188cf264ccc747f7cbfc979f28c3d9 --- /dev/null +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -0,0 +1,74 @@ +from typing import Literal + +import numpy as np +from pydantic import BaseModel, Field, field_validator + + +class OpenAICreateSpeechRequest(BaseModel): + input: str + model: str | None = None + voice: str | None = Field( + default=None, + description="Voice to use. For OpenAI: alloy, echo, etc. For Qwen3-TTS: Vivian, Ryan, etc.", + ) + instructions: str | None = Field( + default=None, + description="Instructions for voice style/emotion (maps to 'instruct' for Qwen3-TTS)", + ) + response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" + speed: float | None = Field( + default=1.0, + ge=0.25, + le=4.0, + ) + stream_format: Literal["sse", "audio"] | None = "audio" + + # Qwen3-TTS specific parameters + task_type: Literal["CustomVoice", "VoiceDesign", "Base"] | None = Field( + default=None, + description="TTS task type: CustomVoice, VoiceDesign, or Base (voice clone)", + ) + language: str | None = Field( + default=None, + description="Language code (e.g., 'Chinese', 'English', 'Auto')", + ) + ref_audio: str | None = Field( + default=None, + description="Reference audio for voice cloning (Base task). URL, base64, or file path.", + ) + ref_text: str | None = Field( + default=None, + description="Transcript of reference audio for voice cloning (Base task)", + ) + x_vector_only_mode: bool | None = Field( + default=None, + description="Use speaker embedding only without in-context learning (Base task)", + ) + max_new_tokens: int | None = Field( + default=None, + description="Maximum tokens to generate", + ) + + @field_validator("stream_format") + @classmethod + def validate_stream_format(cls, v: str) -> str: + if v == "sse": + raise ValueError("'sse' is not a supported stream_format yet. 
Please use 'audio'.") + return v + + +class CreateAudio(BaseModel): + audio_tensor: np.ndarray + sample_rate: int = 24000 + response_format: str = "wav" + speed: float = 1.0 + stream_format: Literal["sse", "audio"] | None = "audio" + base64_encode: bool = True + + class Config: + arbitrary_types_allowed = True + + +class AudioResponse(BaseModel): + audio_data: bytes | str + media_type: str diff --git a/vllm_omni/entrypoints/openai/protocol/chat_completion.py b/vllm_omni/entrypoints/openai/protocol/chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc93f672746f7920663101ef2844f09c0e8b4be --- /dev/null +++ b/vllm_omni/entrypoints/openai/protocol/chat_completion.py @@ -0,0 +1,5 @@ +from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionStreamResponse + + +class OmniChatCompletionStreamResponse(ChatCompletionStreamResponse): + modality: str | None = "text" diff --git a/vllm_omni/entrypoints/openai/protocol/images.py b/vllm_omni/entrypoints/openai/protocol/images.py new file mode 100644 index 0000000000000000000000000000000000000000..63514b191e8e4eb5f2b88b17de18e66a14fea6c1 --- /dev/null +++ b/vllm_omni/entrypoints/openai/protocol/images.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +OpenAI-compatible protocol definitions for image generation. + +This module provides Pydantic models that follow the OpenAI DALL-E API specification +for text-to-image generation, with vllm-omni specific extensions. +""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field, field_validator + + +class ResponseFormat(str, Enum): + """Image response format""" + + B64_JSON = "b64_json" + URL = "url" # Not implemented in PoC + + +class ImageGenerationRequest(BaseModel): + """ + OpenAI DALL-E compatible image generation request. + + Follows the OpenAI Images API specification with vllm-omni extensions + for advanced diffusion parameters. + """ + + # Required fields + prompt: str = Field(..., description="Text description of the desired image(s)") + + # OpenAI standard fields + model: str | None = Field( + default=None, + description="Model to use (optional, uses server's configured model if omitted)", + ) + n: int = Field(default=1, ge=1, le=10, description="Number of images to generate") + size: str | None = Field( + default=None, + description="Image dimensions in WIDTHxHEIGHT format (e.g., '1024x1024', uses model defaults if omitted)", + ) + response_format: ResponseFormat = Field(default=ResponseFormat.B64_JSON, description="Format of the returned image") + user: str | None = Field(default=None, description="User identifier for tracking") + + @field_validator("size") + @classmethod + def validate_size(cls, v): + """Validate size parameter. + + Accepts any string in 'WIDTHxHEIGHT' format (e.g., '1024x1024', '512x768'). + No restrictions on specific dimensions - models can handle arbitrary sizes. 
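+
+        For example, '1024x1024' or '512x768' pass this check, while a value
+        such as '1024' (no 'x' separator) is rejected; full parsing of the two
+        integers happens later in parse_size().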
+ """ + if v is None: + return None + # Validate string format + if not isinstance(v, str) or "x" not in v: + raise ValueError("size must be in format 'WIDTHxHEIGHT' (e.g., '1024x1024')") + return v + + @field_validator("response_format") + @classmethod + def validate_response_format(cls, v): + """Validate response format - only b64_json is supported.""" + if v is not None and v != ResponseFormat.B64_JSON: + raise ValueError(f"Only 'b64_json' response format is supported, got: {v}") + return v + + # vllm-omni extensions for diffusion control + negative_prompt: str | None = Field(default=None, description="Text describing what to avoid in the image") + num_inference_steps: int | None = Field( + default=None, + ge=1, + le=200, + description="Number of diffusion sampling steps (uses model defaults if not specified)", + ) + guidance_scale: float | None = Field( + default=None, + ge=0.0, + le=20.0, + description="Classifier-free guidance scale (uses model defaults if not specified)", + ) + true_cfg_scale: float | None = Field( + default=None, + ge=0.0, + le=20.0, + description="True CFG scale (model-specific parameter, may be ignored if not supported)", + ) + seed: int | None = Field(default=None, description="Random seed for reproducibility") + + # vllm-omni extension for per-request LoRA. + # This mirrors the `extra_body.lora` convention in /v1/chat/completions. + lora: dict[str, Any] | None = Field( + default=None, + description=( + "Optional LoRA adapter for this request. Expected shape: " + "{name/path/scale/int_id}. Field names are flexible " + "(e.g. name|lora_name|adapter, path|lora_path|local_path, " + "scale|lora_scale, int_id|lora_int_id)." + ), + ) + + # VAE memory optimizations (set at model init, included for completeness) + vae_use_slicing: bool | None = Field(default=False, description="Enable VAE slicing") + vae_use_tiling: bool | None = Field(default=False, description="Enable VAE tiling") + + +class ImageData(BaseModel): + """Single generated image data""" + + b64_json: str | None = Field(default=None, description="Base64-encoded PNG image") + url: str | None = Field(default=None, description="Image URL (not implemented)") + revised_prompt: str | None = Field(default=None, description="Revised prompt (OpenAI compatibility, always null)") + + +class ImageGenerationResponse(BaseModel): + """ + OpenAI DALL-E compatible image generation response. + + Returns generated images with metadata. 
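+
+    For illustration, a serialized response might look like (values are
+    placeholders):
+
+        {"created": 1700000000,
+         "data": [{"b64_json": "<base64 image>", "url": null, "revised_prompt": null}],
+         "output_format": "png",
+         "size": "1024x1024"}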
+ """ + + created: int = Field(..., description="Unix timestamp of when the generation completed") + data: list[ImageData] = Field(..., description="Array of generated images") + output_format: str = Field(None, description="The output format of the image generation") + size: str = Field(None, description="The size of the image generated") diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..cfee79157e185128e02e446af6f31beb543d5b13 --- /dev/null +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -0,0 +1,2169 @@ +import asyncio +import base64 +import json +import time +import uuid +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence +from datetime import datetime, timedelta, timezone +from io import BytesIO +from typing import TYPE_CHECKING, Any, Final, Optional, cast + +import jinja2 +import torch +from fastapi import Request +from PIL import Image +from pydantic import TypeAdapter +from vllm.renderers import RendererLike + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt + +try: + import soundfile +except ImportError: + soundfile = None + + +from openai.types.chat.chat_completion_audio import ChatCompletionAudio as OpenAIChatCompletionAudio +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + get_history_tool_calls_cnt, + make_tool_call_id, +) +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, +) +from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ErrorInfo, + ErrorResponse, + FunctionCall, + FunctionDefinition, + PromptTokenUsageInfo, + RequestResponseMetadata, + ToolCall, + UsageInfo, +) +from vllm.entrypoints.openai.engine.serving import ChatLikeRequest, clamp_prompt_logprobs +from vllm.entrypoints.openai.parser.harmony_utils import ( + get_streamable_parser_for_assistant, + parse_chat_output, +) +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest +from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls +from vllm.entrypoints.utils import should_include_usage +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.renderers.hf import ( + resolve_chat_template_content_format, +) +from vllm.renderers.hf import ( + safe_apply_chat_template as apply_hf_chat_template, +) +from vllm.renderers.mistral import ( + safe_apply_chat_template as apply_mistral_chat_template, +) +from vllm.sampling_params import SamplingParams +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import ( + MistralTokenizer, + maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params, +) +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.collection_utils import as_list, is_list_of + +from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures 
+from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin +from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse +from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id +from vllm_omni.outputs import OmniRequestOutput + +if TYPE_CHECKING: + from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion + +logger = init_logger(__name__) + + +class OmniOpenAIServingChat(OpenAIServingChat, AudioMixin): + """OpenAI-compatible chat serving for both LLM and Diffusion models. + + This class extends OpenAIServingChat to support: + - Standard LLM chat completions + - Diffusion model image generation via chat interface + + For diffusion mode, use the `for_diffusion` class method to create an instance. + """ + + # Diffusion mode attributes + _diffusion_mode: bool = False + _diffusion_engine: Optional["AsyncOmniDiffusion"] = None + _diffusion_model_name: str = "" + + @classmethod + def for_diffusion( + cls, + diffusion_engine: "AsyncOmniDiffusion", + model_name: str, + ) -> "OmniOpenAIServingChat": + """Create a chat serving instance for diffusion models. + + Args: + diffusion_engine: The async diffusion engine + model_name: Name of the model being served + + Returns: + OmniOpenAIServingChat instance configured for diffusion mode + + Note: + Request-level parameters (num_inference_steps, guidance_scale, seed, + height, width, num_frames, fps, etc.) are passed per-request via the API. + """ + instance = cls.__new__(cls) + instance._diffusion_mode = True + instance._diffusion_engine = diffusion_engine + instance._diffusion_model_name = model_name + return instance + + async def create_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse: + """ + Chat Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI + Chat Completion API. + + For diffusion models, this generates images and returns them + in a chat completion response format. + """ + # Handle diffusion mode + if self._diffusion_mode: + return await self._create_diffusion_chat_completion(request, raw_request) + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + try: + lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True) + + model_name = self.models.model_name(lora_request) + + renderer = self.renderer + tokenizer = renderer.get_tokenizer() + if tokenizer is None: + tokenizer = await self.engine_client.get_tokenizer() + + tool_parser = self.tool_parser + + if isinstance(tokenizer, MistralTokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` + maybe_serialize_tool_calls(request) + truncate_tool_call_ids(request) + validate_request_params(request) + + # Check if tool parsing is unavailable (common condition) + tool_parsing_unavailable = ( + tool_parser is None and not isinstance(tokenizer, MistralTokenizer) and not self.use_harmony + ) + + # Validate tool_choice when tool parsing is required but unavailable + if tool_parsing_unavailable and request.tool_choice not in ( + None, + "none", + ): + if request.tool_choice == "auto" and not self.enable_auto_tools: + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set' + ) + elif request.tool_choice != "auto": + # "required" or named tool requires tool parser + return self.create_error_response( + f'tool_choice="{request.tool_choice}" requires --tool-call-parser to be set' + ) + + if request.tools is None or (request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none): + tool_dicts = None + else: + tool_dicts = [tool.model_dump() for tool in request.tools] + + if not self.use_harmony: + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + chat_template_kwargs = request.chat_template_kwargs or {} + chat_template_kwargs.update(reasoning_effort=request.reasoning_effort) + + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + renderer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=getattr(request, "documents", None), + chat_template_kwargs=chat_template_kwargs, + default_chat_template_kwargs=self.default_chat_template_kwargs, + tool_parser=tool_parser, + add_special_tokens=request.add_special_tokens, + ) + else: + should_include_tools = tool_dicts is not None + conversation, engine_prompts = self._make_request_with_harmony(request, should_include_tools) + request_prompts = [engine_prompt.get("prompt_token_ids", []) for engine_prompt in engine_prompts] + + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + request_id = f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + 
output_modalities = getattr(request, "modalities", self.engine_client.output_modalities) + request.modalities = ( + output_modalities if output_modalities is not None else self.engine_client.output_modalities + ) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + if hasattr(request, "sampling_params_list"): + sampling_params_list = self._to_sampling_params_list(request.sampling_params_list) + else: + # Use standard OpenAI API parameters for comprehension stage + sampling_params_list = self._build_sampling_params_list_from_request(request) + + self._log_inputs( + request_id, + request_prompts[i], + params_list=sampling_params_list, + lora_request=lora_request, + ) + + generator = self.engine_client.generate( + prompt=engine_prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=output_modalities, + ) + + generators.append(generator) + except ValueError as e: + return self.create_error_response(e) + + assert len(generators) == 1 + (result_generator,) = generators + + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) + + try: + return await self.chat_completion_full_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) + except ValueError as e: + return self.create_error_response(e) + + async def _preprocess_chat( + self, + request: ChatLikeRequest | ResponsesRequest, + renderer: RendererLike, + messages: list[ChatCompletionMessageParam], + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: list[dict[str, Any]] | None = None, + documents: list[dict[str, str]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + default_chat_template_kwargs: dict[str, Any] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + add_special_tokens: bool = False, + ) -> tuple[ + list[ConversationMessage], + Sequence[PromptType], + list[TokensPrompt], + ]: + model_config = self.model_config + tokenizer = renderer.get_tokenizer() if renderer is not None else None + + if tokenizer is None or isinstance(tokenizer, MistralTokenizer): + resolved_content_format = ( + chat_template_content_format if chat_template_content_format != "auto" else "string" + ) + else: + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tool_dicts, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + # OMNI: Updated for vLLM v0.15.0 API - resolve_items() returns (mm_data, mm_uuids) tuple + conversation, mm_future = parse_chat_messages_futures( + messages, + model_config, + content_format=resolved_content_format, + mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), + ) + + # Merge default_chat_template_kwargs with request-provided kwargs + # Request kwargs take precedence over defaults + merged_kwargs = self._prepare_extra_chat_template_kwargs( + chat_template_kwargs, + default_chat_template_kwargs, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + 
documents=documents, + ) + _chat_template_kwargs.update(merged_kwargs) + + request_prompt: str | list[int] + + if tokenizer is None: + request_prompt = "placeholder" + elif isinstance(tokenizer, MistralTokenizer): + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + hf_chat_template_kwargs = dict(_chat_template_kwargs) + hf_chat_template_kwargs.pop("tools", None) + hf_chat_template_kwargs.pop("chat_template", None) + request_prompt = apply_hf_chat_template( + model_config=model_config, + tokenizer=tokenizer, + conversation=conversation, + tools=tool_dicts, + chat_template=chat_template, + tokenize=False, + **hf_chat_template_kwargs, + ) + + # OMNI: Await the combined future to get both mm_data and mm_uuids + mm_data, mm_uuids = await mm_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and ( + hasattr(request, "tool_choice") and request.tool_choice != "none" + ) + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request + ) + + if tokenizer is None: + assert isinstance(request_prompt, str), ( + "Prompt has to be a string", + "when the tokenizer is not initialised", + ) + prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) + elif isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt, + ) + + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + + mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None) + if mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return conversation, [request_prompt], [engine_prompt] + + def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[SamplingParams]: + final_sampling_params_list = [] + for sampling_params in sampling_params_list: + if isinstance(sampling_params, dict): + final_sampling_params_list.append(SamplingParams(**sampling_params)) + elif isinstance(sampling_params, SamplingParams): + final_sampling_params_list.append(sampling_params) + else: + raise ValueError(f"Invalid sampling params: {sampling_params}") + return final_sampling_params_list + + def _get_comprehension_stage_index(self) -> int: + for idx, stage in enumerate(self.engine_client.stage_list): + if stage.is_comprehension: + return idx + raise ValueError("No comprehension stage (is_comprehension=True) found in stage_list") + + # OpenAI API standard sampling parameters that can be safely overridden. 
+ # These are the most commonly used parameters with compatible types + # between ChatCompletionRequest and SamplingParams. + # Users who need more control can use sampling_params_list in extra_body. + _OPENAI_SAMPLING_FIELDS: set[str] = { + "temperature", + "top_p", + "max_tokens", + "seed", + "stop", + "frequency_penalty", + "presence_penalty", + } + + def _apply_request_overrides( + self, + default_params: SamplingParams, + request: ChatCompletionRequest, + ) -> SamplingParams: + """Clone default params and override with user-provided request values. + + Starts with YAML defaults and only overrides fields that the user + explicitly provided (non-None values) in the request. + + Args: + default_params: Default SamplingParams from stage config YAML. + request: The chat completion request containing user-provided values. + + Returns: + New SamplingParams with YAML defaults overridden by request values. + """ + params = default_params.clone() + + for field_name in self._OPENAI_SAMPLING_FIELDS: + value = getattr(request, field_name, None) + if value is not None: + setattr(params, field_name, value) + + return params + + def _build_sampling_params_list_from_request( + self, + request: ChatCompletionRequest, + ) -> list[SamplingParams]: + """Build sampling_params_list using standard OpenAI API parameters. + + For the comprehension stage, starts with YAML defaults and overrides with + user-provided request values. For other stages, uses cloned YAML defaults. + + This approach ensures all YAML defaults (including seed, detokenize, etc.) + are preserved while allowing users to override specific parameters. + + Args: + request: The chat completion request containing OpenAI API parameters. + + Returns: + List of SamplingParams, one for each stage. + """ + default_params_list = self.engine_client.default_sampling_params_list + comprehension_idx = self._get_comprehension_stage_index() + + sampling_params_list = [] + for idx, default_params in enumerate(default_params_list): + if isinstance(default_params, dict): + default_params = SamplingParams(**default_params) + if idx == comprehension_idx: + params = self._apply_request_overrides(default_params, request) + sampling_params_list.append(params) + else: + # For other stages, clone default params + sampling_params_list.append(default_params.clone()) + + return sampling_params_list + + def _log_inputs( + self, + request_id: str, + inputs: PromptType, + params_list: list[SamplingParams] | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return + prompt, prompt_token_ids, prompt_embeds = None, None, None + if isinstance(inputs, str): + prompt = inputs + elif isinstance(inputs, list): + prompt_token_ids = inputs + else: + prompt = getattr(inputs, "prompt", None) + prompt_token_ids = getattr(inputs, "prompt_token_ids", None) + + logger.info( + "Received request %s: prompt: %r, params_list: %s, prompt_token_ids: %s, prompt_embeds shape: %s, lora_request: %s.", # noqa: E501 + request_id, + prompt, + params_list, + prompt_token_ids, + prompt_embeds.shape if prompt_embeds is not None else None, + lora_request, + ) + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ): + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration_dict = {} 
+ assert hasattr(request, "modalities") and request.modalities is not None, ( + "Streaming request must specify output modalities" + ) + for modality in request.modalities: + first_iteration_dict[modality] = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [get_streamable_parser_for_assistant() for _ in range(num_choices)] + harmony_tools_streamed = [False] * num_choices + tools_streamed = [False] * num_choices + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = not tool_choice_function_name and self._should_stream_with_auto_tool_parsing(request) + + all_previous_token_ids: list[list[int]] | None + function_name_returned = [False] * num_choices + if self.tool_call_id_type == "kimi_k2": + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 + + # Always track previous_texts for comprehensive output logging + previous_texts = [""] * num_choices + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. + if tool_choice_auto or self.reasoning_parser: + # These are only required in "auto" tool choice case + all_previous_token_ids = [[]] * num_choices + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices + else: + all_previous_token_ids = None + + try: + if self.reasoning_parser: + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore + ) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(e) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: list[ToolParser | None] = [self.tool_parser(tokenizer)] * num_choices + else: + tool_parsers = [None] * num_choices + except Exception as e: + logger.exception("Error in tool parser creation.") + data = self.create_streaming_error_response(e) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + stream_options = request.stream_options + include_usage, include_continuous_usage = should_include_usage(stream_options, self.enable_force_include_usage) + + try: + async for omni_res in result_generator: + final_output_type = omni_res.final_output_type + res = omni_res.request_output + if final_output_type not in first_iteration_dict: + logger.warning(f"final output type: {final_output_type} is not needed by the request") + continue + + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # Initialize role before conditional blocks to avoid UnboundLocalError + # when handling audio/image responses + role = self.get_chat_request_role(request) + + 
# We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration_dict[final_output_type] and final_output_type == "text": + num_cached_tokens = res.num_cached_tokens + # Send first response for each choice with role + # NOTE: num_choices defaults to 1 so this usually executes once per request + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None, + ) + + # return prompt_token_ids at the first chunk ever + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + prompt_token_ids=(res.prompt_token_ids if request.return_token_ids else None), + modality=final_output_type, + ) + + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content: str | list[dict[str, str]] = "" + if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + logprobs=None, + finish_reason=None, + ) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + modality=final_output_type, + ) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration_dict[final_output_type] = False + + if final_output_type == "text": + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + if self.use_harmony: + harmony_parser = harmony_parsers[i] + prev_recipient = harmony_parser.current_recipient + delta_text = "" + for token_id in output.token_ids: + harmony_parser.process(token_id) + delta_text += harmony_parser.last_content_delta or "" + cur_channel = harmony_parser.current_channel + cur_recipient = harmony_parser.current_recipient + else: + # output.text is cumulative, extract only the delta portion + previous_text = previous_texts[i] if previous_texts else "" + if output.text is not None: + delta_text = output.text[len(previous_text) :] + else: + delta_text = "" + + if not delta_text and not output.token_ids and not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + delta_message: DeltaMessage | None + + # just update previous_texts and 
previous_token_ids + if tool_choice_auto or self.reasoning_parser: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + # avoid the None + list error. + if previous_token_ids: + current_token_ids = previous_token_ids + as_list(output.token_ids) + else: + current_token_ids = as_list(output.token_ids) + + if self.use_harmony: + if cur_channel == "final": + delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if request.include_reasoning: + delta_message = DeltaMessage(reasoning=delta_text) + else: + delta_message = None + elif ( + cur_channel == "commentary" and cur_recipient and cur_recipient.startswith("functions.") + ): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split("functions.", 1)[1] + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ] + ) + elif delta_text: + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall(arguments=delta_text), + ) + ] + ) + else: + delta_message = None + + if delta_message is not None: + harmony_tools_streamed[i] = True + else: + delta_message = None + # handle streaming deltas for tools with named tool_choice + elif tool_choice_function_name: + if ( + self.reasoning_parser + and not reasoning_end_arr[i] + and not reasoning_parser.is_reasoning_end(previous_token_ids) + ): + assert reasoning_parser is not None + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + # When encountering think end id in delta_token_ids + # or think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Only keep 'content', remove 'reasoning'. 
+ if reasoning_parser.is_reasoning_end(as_list(output.token_ids)) or ( + res.prompt_token_ids and reasoning_parser.is_reasoning_end(res.prompt_token_ids) + ): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + # This need to be added to next `delta_text` + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + else: + # Just to add remaining `content` + if self.reasoning_parser: + delta_text = previous_text + delta_text + current_text = "" + + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall(arguments=delta_text), + index=i, + ) + else: + delta_tool_call = DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text, + ), + index=i, + ) + function_name_returned[i] = True + + delta_message = DeltaMessage( + tool_calls=[ + delta_tool_call, + ] + ) + tools_streamed[i] = True + + elif request.tool_choice == "required": + assert previous_texts is not None + previous_text = previous_texts[i] + current_text = previous_text + delta_text + fn_name_returned = function_name_returned[i] + output_token_ids = as_list(output.token_ids) + + if ( + self.reasoning_parser is not None + and not reasoning_end_arr[i] + and res.prompt_token_ids + and reasoning_parser.is_reasoning_end(res.prompt_token_ids) + ): + reasoning_end_arr[i] = True + + if self.reasoning_parser and not reasoning_end_arr[i]: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + # reasoning ended + current_text = "" + + else: + # either finished reasoning or no reasoning at all + content = current_text + + delta_message, function_name_returned[i] = self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt, + ) + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): + history_tool_call_cnt += 1 + tools_streamed[i] = True + + # handle streaming deltas for tools with "auto" tool choice + # and reasoning parser + elif tool_choice_auto and self.reasoning_parser: + assert tool_parser is not None + assert reasoning_parser is not None + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + if not reasoning_end_arr[i]: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + # When encountering think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. 
+ if res.prompt_token_ids and reasoning_parser.is_reasoning_end(res.prompt_token_ids): + reasoning_end_arr[i] = True + current_token_ids = output_token_ids + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + current_token_ids = reasoning_parser.extract_content_ids(output_token_ids) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + # handle tool calls only after reasoning is done, + else: + delta_token_ids = output_token_ids + # First time to tool call, + # add the remaining text and token ids + # to delta from previous + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + + # when only reasoning + elif self.reasoning_parser: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # update the previous values for the next iteration + if (tool_choice_auto or self.reasoning_parser) and not self.use_harmony: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + else: + # Update for comprehensive logging even in simple case + assert previous_texts is not None + previous_texts[i] += delta_text + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. 
because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + if output.finish_reason is None and not request.return_token_ids: + continue + delta_message = DeltaMessage() + + # Log streaming delta if output logging is enabled + if self.enable_log_outputs and self.request_logger: + delta_content = "" + if delta_message.content: + delta_content = delta_message.content + elif delta_message.tool_calls: + delta_content = "".join( + tc.function.arguments + for tc in delta_message.tool_calls + if tc.function and tc.function.arguments + ) + + if delta_content: + self.request_logger.log_outputs( + request_id=request_id, + outputs=delta_content, + output_token_ids=as_list(output.token_ids), + finish_reason=output.finish_reason, + is_streaming=True, + delta=True, + ) + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using structured outputs + auto_tools_called = False + if tool_parser: + auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens(delta_message, output) and tool_parser: + latest_delta_len = 0 + if ( + isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall, + ) + ) and isinstance(delta_message.tool_calls[0].function.arguments, str): + latest_delta_len = len(delta_message.tool_calls[0].function.arguments) + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get("arguments", {}), + ensure_ascii=False, + ) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace(actual_call, "", 1) + # set that as a delta message + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=index, + function=DeltaFunctionCall(arguments=remaining_call).model_dump( + exclude_none=True + ), + ) + ] + ) + + # Send the finish response for each request.n only once + # In OpenAI's API, when a tool is called, the + # finish_reason is: + # "tool_calls" for "auto" or "required" tool calls, + # and "stop" for named tool calls. 
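+ # Summary of the branch below:
+ #   auto tool call parsed, tool args streamed without a named tool_choice,
+ #   or a harmony tool call streamed -> finish_reason "tool_calls"
+ #   anything else -> the model's finish_reason, falling back to "stop"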
+ if ( + auto_tools_called + or (tools_streamed[i] and not tool_choice_function_name) + or (self.use_harmony and harmony_tools_streamed[i]) + ): + finish_reason_ = "tool_calls" + else: + finish_reason_ = output.finish_reason if output.finish_reason else "stop" + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=finish_reason_, + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + + finish_reason_sent[i] = True + + choice_data = maybe_filter_parallel_tool_calls(choice_data, request) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + modality=final_output_type, + ) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + elif final_output_type == "audio": + role = self.get_chat_request_role(request) + choices_data = self._create_audio_choice(omni_res, role, request, stream=True) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=choices_data, + model=model_name, + modality=final_output_type, + ) + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + else: + logger.warning(f"Unsupported streaming final output type: {final_output_type}") + continue + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if include_usage: + completion_tokens = sum(previous_num_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + if self.enable_prompt_tokens_details and num_cached_tokens: + final_usage.prompt_tokens_details = PromptTokenUsageInfo(cached_tokens=num_cached_tokens) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens, + ) + + # Log complete streaming response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + # Log the complete response for each choice + for i in range(num_choices): + full_text = ( + previous_texts[i] + if previous_texts and i < len(previous_texts) + else f"<streaming_complete: {previous_num_tokens[i]} tokens>" + ) + self.request_logger.log_outputs( + request_id=request_id, + outputs=full_text, + output_token_ids=None, # Consider also logging all token IDs + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + + except Exception as e: + 
logger.exception("Error in chat completion stream generator.") + data = self.create_streaming_error_response(e) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: TokenizerLike, + request_metadata: RequestResponseMetadata, + ) -> ErrorResponse | ChatCompletionResponse: + created_time = int(time.time()) + final_res: RequestOutput | None = None + + final_outputs: list[OmniRequestOutput] = [] + try: + async for res in result_generator: + final_outputs.append(res) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + return self.create_error_response(e) + + assert final_outputs is not None + + choices: list[ChatCompletionResponseChoice] = [] + + usage = UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) + role = self.get_chat_request_role(request) + prompt_logprobs = None + prompt_token_ids = None + kv_transfer_params = None + + # Build requested modalities set for filtering + requested_modalities = ( + set(request.modalities) if hasattr(request, "modalities") and request.modalities else None + ) + + for omni_outputs in final_outputs: + choices_data = [] + if omni_outputs.request_output is not None and not getattr(omni_outputs.request_output, "finished", False): + continue + + # Filter outputs based on requested modalites + if requested_modalities is not None and omni_outputs.final_output_type not in requested_modalities: + logger.warning(f"final output type: {omni_outputs.final_output_type} is not needed by the request") + continue + + if omni_outputs.final_output_type == "text": + ( + choices_data, + usage, + prompt_logprobs, + prompt_token_ids, + kv_transfer_params, + ) = self._create_text_choice(request, omni_outputs, tokenizer, conversation, role) + elif omni_outputs.final_output_type == "audio": + choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False) + elif omni_outputs.final_output_type == "image": + choices_data = self._create_image_choice(omni_outputs, role, request, stream=False) + else: + logger.warning(f"Unsupported final output type: {omni_outputs.final_output_type}") + continue + choices.extend(choices_data) + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=prompt_logprobs, + prompt_token_ids=prompt_token_ids, + kv_transfer_params=kv_transfer_params, + ) + + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + output_text = "" + if choice.message.content: + output_text = choice.message.content + elif choice.message.tool_calls: + # For tool calls, log the function name and arguments + tool_call_descriptions = [] + for tc in choice.message.tool_calls: + if hasattr(tc.function, "name") and hasattr(tc.function, "arguments"): + tool_call_descriptions.append(f"{tc.function.name}({tc.function.arguments})") + tool_calls_str = ", ".join(tool_call_descriptions) + output_text = f"[tool_calls: {tool_calls_str}]" + + if output_text: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[choice.index].token_ids + + 
self.request_logger.log_outputs( + request_id=request_id, + outputs=output_text, + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + + return response + + def _create_text_choice( + self, + request: ChatCompletionRequest, + omni_outputs: OmniRequestOutput, + tokenizer: TokenizerLike, + conversation: list[ConversationMessage], + role: str, + ): + final_res = omni_outputs.request_output + if self.tool_call_id_type == "kimi_k2": + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 + + choices: list[ChatCompletionResponseChoice] = [] + + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + tool_call_info = None + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + if self.use_harmony: + reasoning_content, content, _ = parse_chat_output(token_ids) + if not request.include_reasoning: + reasoning_content = None + + if self.tool_parser is not None: + tool_parser = self.tool_parser(tokenizer) + # NOTE: We use token_ids for openai tool parser + tool_call_info = tool_parser.extract_tool_calls( + "", + request=request, + token_ids=token_ids, # type: ignore + ) + content = tool_call_info.content + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + tool_calls=tool_call_info.tool_calls, + ) + else: + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason=( + "tool_calls" + if (tool_call_info is not None and tool_call_info.tools_called) + else (output.finish_reason if output.finish_reason else "stop") + ), + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + continue + + if self.reasoning_parser: + try: + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore + ) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(e) + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. 
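+ # Illustrative split (the exact delimiters depend on the configured parser):
+ #   "<think>outline the steps</think>The answer is 4."
+ #     -> reasoning_content="outline the steps", content="The answer is 4."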
+ reasoning_content, content = reasoning_parser.extract_reasoning(output.text, request=request) + if not request.include_reasoning: + reasoning_content = None + else: + reasoning_content = None + content = output.text + + auto_tools_called = False + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required" + ): + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_call_class = MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class( + function=FunctionCall( + name=request.tool_choice.function.name, + arguments=content, + ) + ) + ], + ) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + assert content is not None + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) + tool_call_ids = [] + for tool_call in tool_calls: + tool_call_ids.append( + make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt, + ) + ) + history_tool_call_cnt += 1 + message = ChatMessage( + role=role, + content="", + tool_calls=[ + tool_call_class( + id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ), + ) + for i, tool_call in enumerate(tool_calls) + ], + reasoning_content=reasoning_content, + ) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # handle when there are tools and tool choice is auto + elif ( + request.tools + and (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools + and self.tool_parser + ): + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(e) + + tool_call_info = tool_parser.extract_tool_calls(content if content is not None else "", request=request) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + ret_content = content + + # try to use content return from tool parser first, + # tool parser may do some modify for the content. 
+ if tool_call_info.content and len(tool_call_info.content) > 0: + ret_content = tool_call_info.content + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=ret_content, + ) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine if tools should be extracted. " + "Returning a standard chat completion." + ) + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason=( + "tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop" + ), + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + choices.append(choice_data) + + if request.echo: + last_msg_content: str | list[dict[str, str]] = "" + if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo(cached_tokens=final_res.num_cached_tokens) + + prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs) + prompt_token_ids = final_res.prompt_token_ids if request.return_token_ids else None + kv_transfer_params = final_res.kv_transfer_params + + return choices, usage, prompt_logprobs, prompt_token_ids, kv_transfer_params + + def _create_audio_choice( + self, omni_outputs: OmniRequestOutput, role: str, request: ChatCompletionRequest, stream: bool = False + ): + choices: list[ChatCompletionResponseChoice] = [] + final_res = omni_outputs.request_output + # OMNI: Access multimodal_output from CompletionOutput (outputs[0]), not from RequestOutput + # Reference: examples/offline_inference/qwen3_omni/end2end.py line 421 + audio_data = final_res.outputs[0].multimodal_output.get("audio") + if stream: + audio_tensor = audio_data[-1].float().detach().cpu().numpy() + else: + if isinstance(audio_data, list): + audio_data = torch.cat(audio_data, dim=-1) + audio_tensor = audio_data.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.flatten() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=24000, + response_format="wav", + speed=1.0, + stream_format="audio", + base64_encode=True, + ) + + audio_response: AudioResponse = self.create_audio(audio_obj) + audio_base64 = audio_response.audio_data + + # Generate unique ID for the audio + audio_id = f"audio-{uuid.uuid4().hex[:16]}" + + # Set expiration time (e.g., 24 hours from now) as Unix timestamp + expires_at = 
int((datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()) + + # Create OpenAIChatCompletionAudio object with all required fields + audio_obj = OpenAIChatCompletionAudio( + id=audio_id, + data=audio_base64, + expires_at=expires_at, + transcript="", # Empty transcript if not available + ) + + for output in final_res.outputs: + if stream: + choice_data = ChatCompletionResponseStreamChoice( + index=output.index, + delta=DeltaMessage(role=role, content=audio_base64), + logprobs=None, + finish_reason="stop", + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + else: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, audio=audio_obj), + logprobs=None, + finish_reason="stop", + stop_reason=None, + ) + choices.append(choice_data) + return choices + + def _create_image_choice( + self, omni_outputs: OmniRequestOutput, role: str, request: ChatCompletionRequest, stream: bool = False + ): + """Create chat completion response choices for image output. + + Converts image tensor or PIL Image output from diffusion models + into base64-encoded image data for API response. + + Args: + omni_outputs: Output containing image data from diffusion stage + role: The role for the response message (e.g., "assistant") + + Returns: + List of ChatCompletionResponseChoice with image content + """ + from PIL import Image + + choices: list[ChatCompletionResponseChoice] = [] + final_res = omni_outputs.request_output + + # Handle different image output formats + images = [] + + # First check omni_outputs.images directly (for diffusion mode via from_diffusion) + if omni_outputs.images: + images = omni_outputs.images + # Fall back to request_output for pipeline mode + # OMNI: Access multimodal_output from CompletionOutput (outputs[0]), not from RequestOutput + elif final_res is not None and final_res.outputs: + completion_output = final_res.outputs[0] + if hasattr(completion_output, "multimodal_output") and completion_output.multimodal_output: + image_data = completion_output.multimodal_output.get("image") + if image_data is not None: + if isinstance(image_data, Image.Image): + images.append(image_data) + elif hasattr(image_data, "cpu"): # Tensor + import numpy as np + + # Convert tensor to PIL Image + img_array = image_data.float().detach().cpu().numpy() + # Handle different tensor formats (CHW -> HWC) + if img_array.ndim == 3 and img_array.shape[0] in [1, 3, 4]: + img_array = np.transpose(img_array, (1, 2, 0)) + # Normalize to 0-255 + if img_array.max() <= 1.0: + img_array = (img_array * 255).astype(np.uint8) + else: + img_array = img_array.astype(np.uint8) + # Handle grayscale + if img_array.ndim == 2: + images.append(Image.fromarray(img_array, mode="L")) + elif img_array.shape[-1] == 1: + images.append(Image.fromarray(img_array.squeeze(-1), mode="L")) + elif img_array.shape[-1] == 3: + images.append(Image.fromarray(img_array, mode="RGB")) + elif img_array.shape[-1] == 4: + images.append(Image.fromarray(img_array, mode="RGBA")) + elif hasattr(final_res, "images") and final_res.images: + images = final_res.images + + # Convert images to base64 + image_contents = [] + for img in images: + with BytesIO() as buffer: + img.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + image_contents.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{img_base64}", + }, + } + ) + + # Create message content + if 
len(image_contents) == 1: + content = image_contents + elif len(image_contents) > 1: + content = image_contents + else: + content = [{"type": "text", "text": "Image generation completed but no images were produced."}] + + # Create response choice + # Use model_construct to bypass validation for multimodal content + # (ChatMessage.content only accepts str, but we need list for images) + # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking + import warnings as warnings_module + + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") + message = ChatMessage.model_construct(role=role) + object.__setattr__(message, "content", content) + # Mark content as set in fields_set to ensure proper serialization + if hasattr(message, "__pydantic_fields_set__"): + message.__pydantic_fields_set__.add("content") + choice_data = ChatCompletionResponseChoice( + index=0, + message=message, + logprobs=None, + finish_reason="stop", + stop_reason=None, + ) + choices.append(choice_data) + + return choices + + # ==================== Diffusion Mode Methods ==================== + + async def _create_diffusion_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Request | None = None, + ) -> ChatCompletionResponse | ErrorResponse: + """Generate images via chat completion interface for diffusion models. + + Args: + request: Chat completion request + raw_request: Raw FastAPI request object + + Returns: + ChatCompletionResponse with generated images or ErrorResponse + """ + try: + request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created_time = int(time.time()) + + # Convert messages to dict format + messages = [] + for msg in request.messages: + if hasattr(msg, "model_dump"): + messages.append(msg.model_dump()) + elif isinstance(msg, dict): + messages.append(msg) + else: + messages.append({"role": getattr(msg, "role", "user"), "content": getattr(msg, "content", "")}) + + # Extract prompt and images from messages + prompt, reference_images = self._extract_diffusion_prompt_and_images(messages) + + if not prompt: + return self._create_error_response("No text prompt found in messages") + + # Extract generation parameters from extra_body (preferred) + # Reference: text_to_image.py and text_to_video.py for supported parameters + extra_body = getattr(request, "extra_body", None) or {} + + # Parse size if provided (supports "1024x1024" format) + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except ValueError: + logger.warning("Invalid size format: %s", extra_body.get("size")) + + # Get request parameters from extra_body + # Text-to-image parameters (ref: text_to_image.py) + num_inference_steps = extra_body.get("num_inference_steps", 50) + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") # Qwen-Image specific + seed = extra_body.get("seed") + negative_prompt = extra_body.get("negative_prompt") + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) + + # Text-to-video parameters (ref: text_to_video.py) + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") # For video high-noise CFG + lora_body = extra_body.get("lora") + + logger.info( + "Diffusion chat request %s: 
prompt=%r, ref_images=%d, params=%s", + request_id, + prompt[:50] + "..." if len(prompt) > 50 else prompt, + len(reference_images), + {k: v for k, v in extra_body.items() if v is not None}, + ) + + # Decode reference images if provided + pil_images: list[Image.Image] = [] + for img_b64 in reference_images: + try: + img_bytes = base64.b64decode(img_b64) + pil_images.append(Image.open(BytesIO(img_bytes))) + except Exception as e: + logger.warning("Failed to decode reference image: %s", e) + + # Build generation kwargs + gen_prompt: OmniTextPrompt = { + "prompt": prompt, + "negative_prompt": negative_prompt, + } + gen_params = OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) + + if guidance_scale is not None: + gen_params.guidance_scale = guidance_scale + + # Add Qwen-Image specific parameter + if true_cfg_scale is not None: + gen_params.true_cfg_scale = true_cfg_scale + + # Add video generation parameters if set + if num_frames is not None: + gen_params.num_frames = num_frames + if guidance_scale_2 is not None: + gen_params.guidance_scale_2 = guidance_scale_2 + + # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). + if lora_body and isinstance(lora_body, dict): + try: + lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + # using "or" directly here may be buggy if `scale=0` + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + if lora_name and lora_path: + lora_req = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = float(lora_scale) + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + # Add reference image if provided + if pil_images: + if len(pil_images) == 1: + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images[0] + else: + od_config = getattr(self._diffusion_engine, "od_config", None) + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + # TODO: entry is asyncOmni. We hack the od config here. + supports_multimodal_inputs = True + if supports_multimodal_inputs: + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images + else: + return self._create_error_response( + "Multiple input images are not supported by the current diffusion model. 
" + "For multi-image editing, start the server with Qwen-Image-Edit-2509 " + "and send multiple images in the user message content.", + status_code=400, + ) + + # Generate image + # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) + if hasattr(self._diffusion_engine, "stage_list"): + # AsyncOmni: iterate through async generator to get final output + diffusion_engine = cast(AsyncOmni, self._diffusion_engine) + result = None + async for output in diffusion_engine.generate( + prompt=gen_prompt, + sampling_params_list=[gen_params], # Pass as single-stage params + request_id=request_id, + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni") + else: + # AsyncOmniDiffusion: direct call + diffusion_engine = cast(AsyncOmniDiffusion, self._diffusion_engine) + result = await diffusion_engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) + # Extract images from result + # Handle nested OmniRequestOutput structure where images might be in request_output + images = getattr(result.request_output, "images", []) + + # Convert images to base64 content + image_contents: list[dict[str, Any]] = [] + for img in images: + with BytesIO() as buffer: + img.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + image_contents.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{img_base64}", + }, + } + ) + + # Build response + if not image_contents: + content = "Image generation completed but no images were produced." + else: + content = image_contents + + # Use model_construct to bypass validation for multimodal content + # (ChatMessage.content only accepts str, but we need list for images) + # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking + import warnings as warnings_module + + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") + message = ChatMessage.model_construct(role="assistant") + object.__setattr__(message, "content", content) + # Mark content as set in fields_set to ensure proper serialization + if hasattr(message, "__pydantic_fields_set__"): + message.__pydantic_fields_set__.add("content") + choice = ChatCompletionResponseChoice.model_construct( + index=0, + message=message, + finish_reason="stop", + logprobs=None, + stop_reason=None, + ) + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=self._diffusion_model_name, + choices=[choice], + usage=UsageInfo( + prompt_tokens=len(prompt.split()), + completion_tokens=1, + total_tokens=len(prompt.split()) + 1, + ), + ) + + logger.info( + "Diffusion chat completed for request %s: %d images", + request_id, + len(images), + ) + + return response + + except Exception as e: + logger.exception("Diffusion chat completion failed: %s", e) + return self._create_error_response( + f"Image generation failed: {str(e)}", + status_code=500, + ) + + def _extract_diffusion_prompt_and_images( + self, + messages: list[dict[str, Any]], + ) -> tuple[str, list[str]]: + """Extract text prompt and base64 images from chat messages. 
+ + Args: + messages: List of chat messages + + Returns: + Tuple of (prompt_text, list_of_base64_images) + """ + prompt_parts: list[str] = [] + images: list[str] = [] + + for message in messages: + role = message.get("role", "") + if role != "user": + continue + + content = message.get("content", "") + + # String content + if isinstance(content, str): + prompt_parts.append(content) + continue + + # List of content items + if isinstance(content, list): + for item in content: + if isinstance(item, str): + prompt_parts.append(item) + elif isinstance(item, dict): + # Handle {"type": "text", "text": "..."} format + if item.get("type") == "text": + prompt_parts.append(item.get("text", "")) + # Handle {"text": "..."} format + elif "text" in item and "type" not in item: + prompt_parts.append(item["text"]) + # Handle {"type": "image_url", "image_url": {"url": "..."}} + elif item.get("type") == "image_url": + url = item.get("image_url", {}).get("url", "") + if url.startswith("data:image"): + try: + _, b64_data = url.split(",", 1) + images.append(b64_data) + except ValueError: + logger.warning("Invalid data URL format") + # Handle {"image": "base64..."} format + elif "image" in item: + images.append(item["image"]) + + prompt = " ".join(prompt_parts).strip() + return prompt, images + + def _create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + ) -> ErrorResponse: + """Create an error response following OpenAI error format.""" + return ErrorResponse( + error=ErrorInfo( + message=message, + type=err_type, + code=status_code, + ) + ) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4b7829f84877ffc4721cb591c262b2db1c5c78 --- /dev/null +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -0,0 +1,310 @@ +import asyncio +from typing import Any + +from fastapi import Request +from fastapi.responses import Response +from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.logger import init_logger +from vllm.utils import random_uuid + +from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin +from vllm_omni.entrypoints.openai.protocol.audio import ( + AudioResponse, + CreateAudio, + OpenAICreateSpeechRequest, +) +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) + +# TTS Configuration (currently supports Qwen3-TTS) +_TTS_MODEL_STAGES: set[str] = {"qwen3_tts"} +_TTS_LANGUAGES: set[str] = { + "Auto", + "Chinese", + "English", + "Japanese", + "Korean", + "German", + "French", + "Russian", + "Portuguese", + "Spanish", + "Italian", +} +_TTS_MAX_INSTRUCTIONS_LENGTH = 500 +_TTS_MAX_NEW_TOKENS_MIN = 1 +_TTS_MAX_NEW_TOKENS_MAX = 4096 + + +class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Load supported speakers + self.supported_speakers = self._load_supported_speakers() + logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + + def _load_supported_speakers(self) -> set[str]: + """Load supported speakers (case-insensitive) from the model configuration.""" + try: + talker_config = self.engine_client.model_config.hf_config.talker_config + + # Check for speakers in either spk_id or speaker_id + for attr_name in ["spk_id", "speaker_id"]: + speakers_dict = getattr(talker_config, attr_name, None) + if speakers_dict and 
isinstance(speakers_dict, dict): + # Normalize to lowercase for case-insensitive matching + return {speaker.lower() for speaker in speakers_dict.keys()} + + logger.warning("No speakers found in talker_config (checked spk_id and speaker_id)") + except Exception as e: + logger.warning(f"Could not load speakers from model config: {e}") + + return set() + + def _is_tts_model(self) -> bool: + """Check if the current model is a supported TTS model.""" + stage_list = getattr(self.engine_client, "stage_list", None) + if stage_list: + for stage in stage_list: + model_stage = getattr(stage, "model_stage", None) + if model_stage in _TTS_MODEL_STAGES: + return True + return False + + def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate TTS request parameters. Returns error message or None.""" + task_type = request.task_type or "CustomVoice" + + # Normalize voice to lowercase for case-insensitive matching + if request.voice is not None: + request.voice = request.voice.lower() + + # Validate input is not empty + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + # Validate language + if request.language is not None and request.language not in _TTS_LANGUAGES: + return f"Invalid language '{request.language}'. Supported: {', '.join(sorted(_TTS_LANGUAGES))}" + + # Validate speaker for CustomVoice task + if task_type == "CustomVoice" and request.voice is not None: + if self.supported_speakers and request.voice not in self.supported_speakers: + return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" + + # Validate Base task requirements + if task_type == "Base": + if request.ref_audio is None: + return "Base task requires 'ref_audio' for voice cloning" + # Validate ref_audio format + if not (request.ref_audio.startswith(("http://", "https://")) or request.ref_audio.startswith("data:")): + return "ref_audio must be a URL (http/https) or base64 data URL (data:...)" + + # Validate cross-parameter dependencies + if task_type != "Base": + if request.ref_text is not None: + return "'ref_text' is only valid for Base task" + if request.x_vector_only_mode is not None: + return "'x_vector_only_mode' is only valid for Base task" + + # Validate VoiceDesign task requirements + if task_type == "VoiceDesign" and not request.instructions: + return "VoiceDesign task requires 'instructions' to describe the voice" + + # Validate instructions length + if request.instructions and len(request.instructions) > _TTS_MAX_INSTRUCTIONS_LENGTH: + return f"Instructions too long (max {_TTS_MAX_INSTRUCTIONS_LENGTH} characters)" + + # Validate max_new_tokens range + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + + def _build_tts_prompt(self, text: str) -> str: + """Build TTS prompt from input text.""" + return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + """Build TTS parameters from request. + + Processes each parameter if present, skips if not. + Values are wrapped in lists as required by the model. 
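+
+ Example (illustrative, mirroring the defaults applied below):
+ input="Hello there", voice="vivian", other fields unset
+ -> {"text": ["Hello there"], "task_type": ["CustomVoice"],
+ "language": ["Auto"], "speaker": ["vivian"], "instruct": [""]}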
+ """ + params: dict[str, Any] = {} + + # Text content (always required) + params["text"] = [request.input] + + # Task type + if request.task_type is not None: + params["task_type"] = [request.task_type] + else: + params["task_type"] = ["CustomVoice"] + + # Language + if request.language is not None: + params["language"] = [request.language] + else: + params["language"] = ["Auto"] + + # Speaker (voice) + if request.voice is not None: + params["speaker"] = [request.voice] + elif params["task_type"][0] == "CustomVoice": + params["speaker"] = ["Vivian"] # Default for CustomVoice + + # Instructions for style/emotion control + if request.instructions is not None: + params["instruct"] = [request.instructions] + else: + params["instruct"] = [""] + + # Voice clone parameters (used with Base task) + if request.ref_audio is not None: + params["ref_audio"] = [request.ref_audio] + if request.ref_text is not None: + params["ref_text"] = [request.ref_text] + if request.x_vector_only_mode is not None: + params["x_vector_only_mode"] = [request.x_vector_only_mode] + + # Let the model's generate_config supply defaults unless the user + # explicitly overrides max_new_tokens in the request. + if request.max_new_tokens is not None: + params["max_new_tokens"] = [request.max_new_tokens] + + return params + + async def create_speech( + self, + request: OpenAICreateSpeechRequest, + raw_request: Request | None = None, + ): + """ + Create Speech API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createSpeech + for the API specification. This API mimics the OpenAI + Create Speech API. + + For Qwen3-TTS models, additional parameters are supported: + - task_type: "CustomVoice", "VoiceDesign", or "Base" + - language: Language code (e.g., "Chinese", "English", "Auto") + - voice: Speaker name (e.g., "Vivian", "Ryan") for CustomVoice + - instructions: Voice style/emotion instructions + - ref_audio: Reference audio for voice cloning (Base task) + - ref_text: Transcript of reference audio (Base task) + - x_vector_only_mode: Use speaker embedding only (Base task) + + NOTE: Streaming audio generation is not currently supported. + """ + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + if self.engine_client.errored: + raise self.engine_client.dead_error + + request_id = f"speech-{random_uuid()}" + + try: + if self._is_tts_model(): + # Validate TTS parameters + validation_error = self._validate_tts_request(request) + if validation_error: + return self.create_error_response(validation_error) + + # Build TTS parameters and prompt + tts_params = self._build_tts_params(request) + prompt_text = self._build_tts_prompt(request.input) + prompt = { + "prompt": prompt_text, + "additional_information": tts_params, + } + else: + # Fallback for unsupported models + tts_params = {} + prompt = {"prompt": request.input} + + logger.info( + "TTS speech request %s: text=%r, task_type=%s", + request_id, + request.input[:50] + "..." 
if len(request.input) > 50 else request.input, + tts_params.get("task_type", ["unknown"])[0], + ) + + sampling_params_list = self.engine_client.default_sampling_params_list + + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + + final_output: OmniRequestOutput | None = None + async for res in generator: + final_output = res + + if final_output is None: + return self.create_error_response("No output generated from the model.") + + # Extract audio from output + # Audio can be in final_output.multimodal_output or final_output.request_output.multimodal_output + # Support both "audio" and "model_outputs" keys for compatibility with different models + audio_output = None + if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: + audio_output = final_output.multimodal_output + if not audio_output and hasattr(final_output, "request_output"): + if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): + audio_output = final_output.request_output.multimodal_output + + # Check for audio data using either "audio" or "model_outputs" key + audio_key = None + if audio_output: + if "audio" in audio_output: + audio_key = "audio" + elif "model_outputs" in audio_output: + audio_key = "model_outputs" + + if not audio_output or audio_key is None: + return self.create_error_response("TTS model did not produce audio output.") + + audio_tensor = audio_output[audio_key] + sample_rate = audio_output.get("sr", 24000) + if hasattr(sample_rate, "item"): + sample_rate = sample_rate.item() + + # Convert tensor to numpy + if hasattr(audio_tensor, "float"): + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + # Squeeze batch dimension if present, but preserve channel dimension for stereo + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.squeeze() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=int(sample_rate), + response_format=request.response_format or "wav", + speed=request.speed or 1.0, + stream_format=request.stream_format, + base64_encode=False, + ) + + audio_response: AudioResponse = self.create_audio(audio_obj) + return Response(content=audio_response.audio_data, media_type=audio_response.media_type) + + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + return self.create_error_response(e) + except Exception as e: + logger.exception("Speech generation failed: %s", e) + return self.create_error_response(f"Speech generation failed: {e}") diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74ad42f045e812eee87ac76143197500273d930a --- /dev/null +++ b/vllm_omni/entrypoints/stage_utils.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +import enum +import json +import logging +import os +from multiprocessing import shared_memory as _shm +from typing import Any + +from omegaconf import OmegaConf + +logger = logging.getLogger(__name__) + + +class OmniStageTaskType(enum.Enum): + GENERATE = "generate" + ABORT = "abort" + SHUTDOWN = "shutdown" + PROFILER_START = "profiler_start" + PROFILER_STOP = "profiler_stop" + + +SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN} + + +def is_profiler_task(task_type: OmniStageTaskType) -> bool: + return task_type in (OmniStageTaskType.PROFILER_START, OmniStageTaskType.PROFILER_STOP) + + 
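+# Illustrative use of the task-type helpers above (the handler sketch below is
+# hypothetical, shown only to clarify how SHUTDOWN_TASK / is_profiler_task are
+# intended to be consumed by a stage loop):
+#
+#   task_type = task["type"]
+#   if task_type is OmniStageTaskType.SHUTDOWN:
+#       ...  # drain and exit the stage loop
+#   elif is_profiler_task(task_type):
+#       ...  # start or stop the profiler
+#   else:
+#       ...  # treat as a GENERATE task
+
+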
+def set_stage_devices( + stage_id: int, + devices: str | int | None, + device_type: str | None = None, +) -> None: + """Configure per-stage device visibility and current device (CUDA or NPU). + + This function sets environment variables that control which devices are visible + to the process, and sets the current device. It must be called BEFORE worker + initialization so that workers see the correct devices. + + Args: + stage_id: Stage identifier for logging + devices: Device specification: + - Comma-separated string (e.g. "2,5,7"): interpreted as logical + indices against the current device visibility env var (e.g. + CUDA_VISIBLE_DEVICES/ASCEND_RT_VISIBLE_DEVICES) when present; + falls back to physical IDs if no mapping exists. Logical index 0 + is used as current device. + - Integer or digit-string: treat as logical index (0-based) into the + current device visibility mapping; map to physical device, then set + env var to this single device. + - None/"cpu": keep default visibility. + - Otherwise: set env var to the provided single device string. + device_type: Device type ("cuda" or "npu"). If None, auto-detects. + + Behavior: + - CUDA: Sets CUDA_VISIBLE_DEVICES and calls torch.cuda.set_device() + - NPU: Sets ASCEND_RT_VISIBLE_DEVICES and calls torch.npu.set_device() + """ + from vllm_omni.platforms import current_omni_platform + + if device_type is None: + device_type = current_omni_platform.device_type + + env_var = current_omni_platform.device_control_env_var + + try: + selected_physical: int | None = None + logical_idx: int | None = None + + if isinstance(devices, str) and "," in devices: + toks = [t.strip() for t in devices.split(",") if t.strip() != ""] + vis = os.environ.get(env_var) + mapped_devices: list[str] = [] + mapping: list[int] = [] + if vis: + try: + mapping = [int(x) for x in vis.split(",") if x.strip() != ""] + except Exception as e: + logger.debug("[Stage-%s] Failed to parse existing %s: %s", stage_id, env_var, e) + for tok in toks: + try: + idx = int(tok) + except Exception: + mapped_devices.append(tok) + continue + if mapping and 0 <= idx < len(mapping): + mapped_devices.append(str(mapping[idx])) + else: + mapped_devices.append(str(idx)) + mapped_devices_str = ",".join(mapped_devices) + os.environ[env_var] = mapped_devices_str + if toks: + try: + selected_physical = int(mapped_devices[0]) + logger.debug( + "[Stage-%s] Set %s to %s; logical 0 -> physical %s", + stage_id, + env_var, + mapped_devices_str, + selected_physical, + ) + except Exception as e: + logger.debug("[Stage-%s] Failed to parse first %s device: %s", stage_id, device_type, e) + selected_physical = None + elif isinstance(devices, (int, str)) and (isinstance(devices, int) or str(devices).isdigit()): + logical_idx = max(0, int(devices)) + vis = os.environ.get(env_var) + if vis: + try: + mapping = [int(x) for x in vis.split(",") if x.strip() != ""] + if 0 <= logical_idx < len(mapping): + selected_physical = mapping[logical_idx] + except Exception as e: + logger.debug("[Stage-%s] Failed to map logical index via %s: %s", stage_id, env_var, e) + selected_physical = None + if selected_physical is None: + selected_physical = int(logical_idx) + os.environ[env_var] = str(selected_physical) + logger.debug( + "[Stage-%s] Logical index %d -> physical %s; set %s to single device", + stage_id, + logical_idx + 1, + selected_physical, + env_var, + ) + elif devices in (None, "cpu"): + logger.debug("[Stage-%s] Using default device visibility (devices=%s)", stage_id, devices) + else: + selected_physical = 
int(str(devices)) + os.environ[env_var] = str(selected_physical) + logger.debug("[Stage-%s] Set %s to single device %s (fallback)", stage_id, env_var, selected_physical) + except Exception as e: + logger.warning("Failed to interpret devices for stage %s: %s", stage_id, e) + + +def serialize_obj(obj: Any) -> bytes: + """Serialize a Python object to bytes using centralized serializer (defaults to cloudpickle).""" + from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer + + return OmniSerializer.serialize(obj) + + +def shm_write_bytes(payload: bytes, name: str | None = None) -> dict[str, Any]: + """Write bytes into SharedMemory and return meta dict {name,size}. + + Caller should close the segment; the receiver should unlink. + """ + try: + shm = _shm.SharedMemory(create=True, size=len(payload), name=name) + except FileExistsError: + if name: + # If name is specified and exists, unlink it and try again + try: + existing = _shm.SharedMemory(name=name) + existing.unlink() + except Exception: + pass + shm = _shm.SharedMemory(create=True, size=len(payload), name=name) + else: + raise + + mv = memoryview(shm.buf) + mv[: len(payload)] = payload + del mv + meta = {"name": shm.name, "size": len(payload)} + try: + shm.close() + except Exception as e: + logger.debug("Failed to close shared memory: %s", e) + return meta + + +def shm_read_bytes(meta: dict[str, Any]) -> bytes: + """Read bytes from SharedMemory by meta {name,size} and cleanup.""" + shm = _shm.SharedMemory(name=meta["name"]) # type: ignore[index] + mv = memoryview(shm.buf) + data = bytes(mv[: meta["size"]]) + del mv + try: + shm.close() + except Exception: + pass + try: + shm.unlink() + except Exception: + pass + return data + + +def _ensure_parent_dir(path: str) -> None: + """Ensure the parent directory for a file path exists (best-effort).""" + try: + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + except Exception: + pass + + +def append_jsonl(path: str, record: dict[str, Any]) -> None: + """Append a JSON record as one line to a JSONL file (best-effort). + + This is safe to call from multiple processes when each process writes + to a distinct file. For concurrent writes to the same file, OS append + semantics typically suffice, but no additional locking is provided. + """ + try: + _ensure_parent_dir(path) + line = json.dumps(record, ensure_ascii=False) + fd = os.open(path, os.O_APPEND | os.O_CREAT | os.O_WRONLY, 0o644) + with os.fdopen(fd, "a", encoding="utf-8") as f: + f.write(line + "\n") + except Exception: + logger.exception("Failed to append JSONL to %s", path) + + +def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: + """Dump object to SHM if serialized size exceeds threshold. + + Returns (True, meta) when dumped; otherwise (False, original_obj). + """ + payload = serialize_obj(obj) + if len(payload) > threshold: + logger.debug(f"Dumping object to SHM with size: {len(payload)}") + return True, shm_write_bytes(payload, name=None) + return False, obj + + +def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any: + """Load object from container that may carry SHM or inline object. + + Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain + decode-time and size metrics. 
+ """ + if shm_key in container: + from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer + + return OmniSerializer.deserialize(shm_read_bytes(container[shm_key])) + return container[obj_key] + + +def maybe_load_from_ipc_with_metrics( + container: dict[str, Any], obj_key: str, shm_key: str +) -> tuple[Any, dict[str, float]]: + """Load object and return (object, metrics) with RX bytes and decode time. + + Metrics keys: + - rx_transfer_bytes: int + - rx_decode_time_ms: float + """ + import time as _time # local import to avoid overhead at module import + + from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer + + t0 = _time.time() + if shm_key in container: + meta = container[shm_key] # type: ignore[index] + payload = shm_read_bytes(meta) + obj = OmniSerializer.deserialize(payload) + try: + rx_bytes = int(meta.get("size", len(payload))) # type: ignore[call-arg] + except Exception: + rx_bytes = len(payload) + else: + obj = container[obj_key] + try: + rx_bytes = len(serialize_obj(obj)) + except Exception: + rx_bytes = 0 + t1 = _time.time() + rx_decode_ms = (t1 - t0) * 1000.0 + return obj, { + "rx_transfer_bytes": int(rx_bytes), + "rx_decode_time_ms": float(rx_decode_ms), + } + + +def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict[str, Any]: + """Return a dict payload for IPC: inline (obj_key) or SHM (shm_key). + + When serialized size exceeds threshold, returns {shm_key: {name,size}}; + otherwise returns {obj_key: obj}. + """ + payload: dict[str, Any] = {} + use_shm, data = maybe_dump_to_shm(obj, threshold) + if use_shm: + payload[shm_key] = data + else: + payload[obj_key] = data + return payload + + +# Convert OmegaConf/objects to plain dicts +def _to_dict(x: Any) -> dict[str, Any]: + try: + if isinstance(x, dict): + return dict(x) + return OmegaConf.to_container(x, resolve=True) # type: ignore[arg-type] + except Exception: + try: + return dict(x) + except Exception: + return {} diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ebff329034d08b2784fb6ad9c46839b0a4adf41 --- /dev/null +++ b/vllm_omni/entrypoints/utils.py @@ -0,0 +1,282 @@ +import os +from collections import Counter +from dataclasses import asdict, is_dataclass +from pathlib import Path +from typing import Any + +from omegaconf import OmegaConf +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_config, get_hf_file_to_dict +from vllm.transformers_utils.repo_utils import file_or_path_exists + +from vllm_omni.entrypoints.stage_utils import _to_dict +from vllm_omni.platforms import current_omni_platform + +# Get the project root directory (2 levels up from this file) +PROJECT_ROOT = Path(__file__).parent.parent.parent + +logger = init_logger(__name__) + + +def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None: + """Inject connector configuration into stage engine arguments.""" + # Prepare omni_kv_config dict + omni_conf_dict = {} + try: + # Access engine_args safely (might be OmegaConf or dict) + existing_args = stage.engine_args + if hasattr(existing_args, "get"): + _oc = existing_args.get("omni_kv_config", None) + if _oc: + if hasattr(_oc, "items"): # dict-like + omni_conf_dict = dict(_oc) + else: # object? 
+ omni_conf_dict = _to_dict(_oc) + except Exception: + omni_conf_dict = {} + + # Inject connector info + omni_conf_dict["connector_config"] = omni_conn_cfg + omni_conf_dict["omni_from_stage"] = omni_from + omni_conf_dict["omni_to_stage"] = omni_to + + # Write back to engine_args + try: + if hasattr(stage.engine_args, "__setitem__"): + stage.engine_args["omni_kv_config"] = omni_conf_dict + else: + setattr(stage.engine_args, "omni_kv_config", omni_conf_dict) + except Exception as e: + # Fallback for OmegaConf or similar if direct set fails? + logger.error(f"Failed to inject omni connector config into stage: {e}") + + +def _try_get_class_name_from_diffusers_config(model: str) -> str | None: + """Try to get class name from diffusers model configuration files. + + Args: + model: Model name or path + + Returns: + Model type string if found, None otherwise + """ + model_index = get_hf_file_to_dict("model_index.json", model, revision=None) + if model_index and isinstance(model_index, dict) and "_class_name" in model_index: + logger.debug(f"Found model_type '{model_index['_class_name']}' in model_index.json") + return model_index["_class_name"] + + return None + + +def _convert_dataclasses_to_dict(obj: Any) -> Any: + """Recursively convert non-serializable objects to OmegaConf-compatible types. + + This is needed because OmegaConf cannot handle: + - Dataclass objects with Literal type annotations (e.g., StructuredOutputsConfig) + - Counter objects (from collections or vllm.utils) + - Set objects + - Other non-primitive types + """ + # IMPORTANT: Check Counter BEFORE dict, since Counter is a subclass of dict + # Handle Counter objects (convert to dict) + # Check by class name first to catch both collections.Counter and vllm.utils.Counter + if hasattr(obj, "__class__") and obj.__class__.__name__ == "Counter": + try: + return dict(obj) + except (TypeError, ValueError): + # If Counter can't be converted to dict, return empty dict + return {} + # Also check isinstance for collections.Counter (must be before dict check) + if isinstance(obj, Counter): + return dict(obj) + # Handle set objects (convert to list) + if isinstance(obj, set): + return list(obj) + # Handle dataclass objects + # Note: asdict() recursively converts nested dataclasses but not Counter objects, + # so we need to recursively process the result + if is_dataclass(obj): + result = asdict(obj) + # Recursively process the result to convert any Counter objects + return _convert_dataclasses_to_dict(result) + # Handle dictionaries (recurse into values) + # Note: This must come AFTER Counter check since Counter is a dict subclass + if isinstance(obj, dict): + return {k: _convert_dataclasses_to_dict(v) for k, v in obj.items()} + # Handle lists and tuples (recurse into items) + if isinstance(obj, (list, tuple)): + return type(obj)(_convert_dataclasses_to_dict(item) for item in obj) + # Try to convert any dict-like object (has keys/values methods) to dict + if hasattr(obj, "keys") and hasattr(obj, "values") and not isinstance(obj, (str, bytes)): + try: + return {k: _convert_dataclasses_to_dict(v) for k, v in obj.items()} + except (TypeError, ValueError, AttributeError): + # If conversion fails, return as-is + return obj + # Primitive types and other objects that OmegaConf can handle + return obj + + +def resolve_model_config_path(model: str) -> str: + """Resolve the stage config file path from the model name. + + Resolves stage configuration path based on the model type and device type. 
+ First tries to find a device-specific YAML file from stage_configs/{device_type}/ + directory. If not found, falls back to the default config file. + + Args: + model: Model name or path (used to determine model_type) + + Returns: + String path to the stage configuration file + + Raises: + ValueError: If model_type cannot be determined + FileNotFoundError: If no stage config file exists for the model type + """ + # Try to get config from standard transformers format first + try: + hf_config = get_config(model, trust_remote_code=True) + model_type = hf_config.model_type + except (ValueError, Exception): + # If standard transformers format fails, try diffusers format + if file_or_path_exists(model, "model_index.json", revision=None): + model_type = _try_get_class_name_from_diffusers_config(model) + if model_type is None: + raise ValueError( + f"Could not determine model_type for diffusers model: {model}. " + f"Please ensure the model has 'model_type' in transformer/config.json or model_index.json" + ) + elif file_or_path_exists(model, "config.json", revision=None): + # Try to read config.json manually for custom models like Bagel that fail get_config + # but have a valid config.json with model_type + try: + config_dict = get_hf_file_to_dict("config.json", model, revision=None) + if config_dict and "model_type" in config_dict: + model_type = config_dict["model_type"] + else: + raise ValueError(f"config.json found but missing 'model_type' for model: {model}") + except Exception as e: + raise ValueError(f"Failed to read config.json for model: {model}. Error: {e}") from e + else: + raise ValueError( + f"Could not determine model_type for model: {model}. " + f"Model is not in standard transformers format and does not have model_index.json. " + f"Please ensure the model has proper configuration files with 'model_type' field" + ) + + default_config_path = current_omni_platform.get_default_stage_config_path() + model_type_str = f"{model_type}.yaml" + complete_config_path = PROJECT_ROOT / default_config_path / model_type_str + if os.path.exists(complete_config_path): + return str(complete_config_path) + + # Fall back to default config + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + return None + return str(stage_config_path) + + +def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: + """Load stage configurations from model's default config file. + + Loads stage configurations based on the model type and device type. + First tries to load a device-specific YAML file from stage_configs/{device_type}/ + directory. If not found, falls back to the default config file. + + Args: + model: Model name or path (used to determine model_type) + + Returns: + List of stage configuration dictionaries + + Raises: + FileNotFoundError: If no stage config file exists for the model type + """ + if base_engine_args is None: + base_engine_args = {} + stage_config_path = resolve_model_config_path(model) + if stage_config_path is None: + return [] + stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args) + return stage_configs + + +def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None = None) -> list: + """Load stage configurations from a YAML file. 
+ + Args: + config_path: Path to the YAML configuration file + + Returns: + List of stage configuration dictionaries from the file's stage_args + """ + if base_engine_args is None: + base_engine_args = {} + config_data = OmegaConf.load(config_path) + stage_args = config_data.stage_args + global_async_chunk = config_data.get("async_chunk", False) + # Convert any nested dataclass objects to dicts before creating OmegaConf + base_engine_args = _convert_dataclasses_to_dict(base_engine_args) + base_engine_args = OmegaConf.create(base_engine_args) + for stage_arg in stage_args: + base_engine_args_tmp = base_engine_args.copy() + # Update base_engine_args with stage-specific engine_args if they exist + if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None: + base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args) + stage_type = getattr(stage_arg, "stage_type", "llm") + if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None and stage_type != "diffusion": + runtime_cfg = stage_arg.runtime + max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) + base_engine_args_tmp["max_num_seqs"] = max_batch_size + base_engine_args_tmp.async_chunk = global_async_chunk + stage_arg.engine_args = base_engine_args_tmp + return stage_args + + +def get_final_stage_id_for_e2e( + output_modalities: list[str] | None, default_modalities: list[str], stage_list: list +) -> int: + """Get the final stage id for e2e. + + Args: + stage_list: List of stage configurations + + Returns: + Final stage id for e2e + """ + last_stage_id = len(stage_list) - 1 + if output_modalities is not None: + prompt_modalities = [] + for modality in output_modalities: + if modality not in default_modalities: + logger.warning(f"Invalid output modality: {modality}, ignoring it") + # TODO: if user specifies unsupported modalities, invalid it and raise an error + continue + prompt_modalities.append(modality) + output_modalities = prompt_modalities + else: + output_modalities = default_modalities + + try: + for _sid in range(last_stage_id, -1, -1): + if ( + getattr(stage_list[_sid], "final_output", False) + and stage_list[_sid].final_output_type in output_modalities + ): + final_stage_id_for_e2e = _sid + break + if final_stage_id_for_e2e < 0: + final_stage_id_for_e2e = last_stage_id + except Exception as e: + logger.debug( + "[Orchestrator] Failed to determine final stage for E2E; \ + falling back to last: %s", + e, + exc_info=True, + ) + final_stage_id_for_e2e = last_stage_id + + return final_stage_id_for_e2e diff --git a/vllm_omni/inputs/__init__.py b/vllm_omni/inputs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py new file mode 100644 index 0000000000000000000000000000000000000000..3684744bad5a43275abac9fcff6b1cf2bf7d6dae --- /dev/null +++ b/vllm_omni/inputs/data.py @@ -0,0 +1,284 @@ +import copy +import pprint +from dataclasses import asdict, dataclass, field +from typing import Any, TypeAlias + +from vllm import PromptType, SamplingParams + +from vllm_omni.lora.request import LoRARequest + +try: + from typing import NotRequired +except ImportError: + # Python < 3.11: use typing_extensions + from typing_extensions import NotRequired + +import torch +from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokenInputs, TokensPrompt + + +class OmniTextPrompt(TextPrompt): + """Text prompt with optional embeddings and additional information. 
+ + Extends TextPrompt to support prompt embeddings and additional + information payloads for direct transfer between pipeline stages. + + Attributes: + prompt_embeds: Optional tensor containing prompt embeddings + additional_information: Optional dictionary containing additional + information (tensors or lists) to pass along with the prompt + """ + + negative_prompt: NotRequired[str] + prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[torch.Tensor] + additional_information: NotRequired[dict[str, Any]] + + +class OmniTokensPrompt(TokensPrompt): + """Tokens prompt with optional embeddings and additional information. + + Extends TokensPrompt to support prompt embeddings and additional + information payloads for direct transfer between pipeline stages. + + Attributes: + prompt_embeds: Optional tensor containing prompt embeddings + additional_information: Optional dictionary containing additional + information (tensors or lists) to pass along with the prompt + """ + + negative_prompt: NotRequired[str] + prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] + """The embeddings of the prompt.""" + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +class OmniTokenInputs(TokenInputs): + """Token inputs with optional embeddings and additional information. + + Extends TokenInputs to support prompt embeddings and additional + information payloads for direct transfer between pipeline stages. + + Attributes: + prompt_embeds: Optional tensor containing prompt embeddings + aligned with token IDs + additional_information: Optional dictionary containing additional + information (tensors or lists) to pass along with the inputs + """ + + # New: optional prompt embeddings aligned with token ids + negative_prompt: NotRequired[str] + prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +class OmniEmbedsPrompt(EmbedsPrompt): + """Embeddings prompt with optional additional information. + + Extends EmbedsPrompt to support additional information payloads + for direct transfer between pipeline stages. + + Attributes: + prompt_embeds: Optional tensor containing prompt embeddings + additional_information: Optional dictionary containing additional + information (tensors or lists) to pass along with the prompt + """ + + # New: optional prompt embeddings aligned with token ids + prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +# Must ensure that all additional prompt types are inherited from vLLM prompt types +# Because TypedDict doesn't support isinstance and are dict. Cannot distinguish them in runtime. 
+# Inheritance ensure that there are only additional fields but not removing fields--safe to route to LLM.generate() +OmniSingletonPrompt: TypeAlias = str | OmniTextPrompt | OmniTokensPrompt | OmniEmbedsPrompt +"""Omni singleton prompt type extending vLLM's SingletonPrompt with additional fields.""" + +OmniPromptType: TypeAlias = PromptType | OmniTextPrompt | OmniTokensPrompt | OmniEmbedsPrompt + + +def token_inputs_omni( + prompt_token_ids: list[int], + prompt: str | None = None, + cache_salt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + additional_information: dict[str, Any] | None = None, +) -> OmniTokenInputs: + """Construct token inputs with optional embeddings and metadata. + + Creates an OmniTokenInputs object with token IDs and optional + embeddings and additional information for pipeline stage transfer. + + Args: + prompt_token_ids: List of token IDs for the prompt + prompt: Optional prompt string + cache_salt: Optional cache salt for prefix caching + prompt_embeds: Optional tensor containing prompt embeddings + additional_information: Optional dictionary containing additional + information (tensors or lists) + + Returns: + OmniTokenInputs instance with the provided data + """ + inputs = OmniTokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + if prompt_embeds is not None: + inputs["prompt_embeds"] = prompt_embeds + if additional_information is not None: + inputs["additional_information"] = additional_information + + return inputs + + +@dataclass +class OmniDiffusionSamplingParams: + """ + The collection of sampling parameters passed to diffusion pipelines. + + This dataclass contains all information needed during the diffusion pipeline + execution, allowing methods to update specific components without needing + to manage numerous individual parameters. 
+ """ + + # Additional text-related parameters + max_sequence_length: int | None = None + prompt_template: dict[str, Any] | None = None + do_classifier_free_guidance: bool = False + + # Batch info + num_outputs_per_prompt: int = 1 + seed: int | None = None + generator: torch.Generator | list[torch.Generator] | None = None + + # layered info + layers: int = 4 + + # cfg info + cfg_normalize: bool = False + + # caption language + use_en_prompt: bool = False + + # different bucket in (640, 1024) to determine the condition and output resolution + resolution: int = 640 + + # Tracking if embeddings are already processed + is_prompt_processed: bool = False + + # Latent tensors + latents: torch.Tensor | None = None + raw_latent_shape: torch.Tensor | None = None + noise_pred: torch.Tensor | None = None + image_latent: torch.Tensor | None = None + + # Latent dimensions + height_latents: list[int] | int | None = None + width_latents: list[int] | int | None = None + num_frames: int = 1 # Default for image models + num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus + + # Original dimensions (before VAE scaling) + height: int | None = None + width: int | None = None + fps: int | None = None + height_not_provided: bool = False + width_not_provided: bool = False + + # Timesteps + timesteps: torch.Tensor | None = None + timestep: torch.Tensor | float | int | None = None + step_index: int | None = None + boundary_ratio: float | None = None + + # Scheduler parameters + num_inference_steps: int = 50 + guidance_scale: float = 0.0 + guidance_scale_provided: bool = False + guidance_scale_2: float | None = None + guidance_rescale: float = 0.0 + eta: float = 0.0 + sigmas: list[float] | None = None + + true_cfg_scale: float | None = None # qwen-image specific now + + n_tokens: int | None = None + extra_step_kwargs: dict[str, Any] = field(default_factory=dict) + + # [Omni] KV Cache Transfer, for bagel model now + past_key_values: Any | None = None # Injected KV Cache + kv_metadata: dict[str, Any] | None = None # Metadata for KV Cache (e.g., kv_lens, ropes) + need_kv_receive: bool = True # Flag to indicate if this request expects KV transfer + + # Component modules + modules: dict[str, Any] = field(default_factory=dict) + + return_trajectory_latents: bool = False + return_trajectory_decoded: bool = False + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + + # Extra parameters that might be needed by specific pipeline implementations + extra_args: dict[str, Any] = field(default_factory=dict) + + # Misc + save_output: bool = True + return_frames: bool = False + + # LoRA + lora_request: LoRARequest | None = None + lora_scale: float = 1.0 + + # STA parameters + STA_param: list | None = None + is_cfg_negative: bool = False + mask_search_final_result_pos: list[list] | None = None + mask_search_final_result_neg: list[list] | None = None + + # VSA parameters + VSA_sparsity: float = 0.0 + # perf_logger: PerformanceLogger | None = None + + # stage logging + # logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) + + # profile + profile: bool = False + num_profiled_timesteps: int = 8 + + # debugging + debug: bool = False + + # results + output: torch.Tensor | None = None + + @property + def batch_size(self): + # This class is changed to only represent a single prompt request + # Only adjust batch size for number of videos per prompt + return self.num_outputs_per_prompt + + def __str__(self): + 
return pprint.pformat(asdict(self), indent=2, width=120) + + def clone(self) -> "OmniDiffusionSamplingParams": + return copy.deepcopy(self) + + +OmniSamplingParams: TypeAlias = SamplingParams | OmniDiffusionSamplingParams diff --git a/vllm_omni/inputs/parse.py b/vllm_omni/inputs/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe46ff537773a849ffddcc15410f42fe517b6e8 --- /dev/null +++ b/vllm_omni/inputs/parse.py @@ -0,0 +1,42 @@ +from vllm.inputs.parse import ( + ParsedEmbedsPrompt, + ParsedSingletonPrompt, + ParsedStrPrompt, + ParsedTextPrompt, + ParsedTokensPrompt, + SingletonPrompt, +) + + +def parse_singleton_prompt_omni(prompt: SingletonPrompt) -> ParsedSingletonPrompt: + """Parse a singleton prompt into a typed parsed prompt. + + Handles omni-specific prompt types including tokens prompts with + embeddings and additional information. Supports string, text, + tokens, and embeddings prompts. + + Args: + prompt: Singleton prompt to parse. Can be a string, TextPrompt, + TokensPrompt (with optional prompt_embeds and additional_information), + or EmbedsPrompt. + + Returns: + ParsedSingletonPrompt containing the parsed prompt with type information + + Raises: + TypeError: If the prompt type is not supported + """ + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. + # Priority tokens: When both tokens and embeds exist, keep both and + # follow the tokens path + if "prompt_token_ids" in prompt: + return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item] + elif "prompt_embeds" in prompt: + return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) + raise TypeError("inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..d04c66331d1eaea2e347cf818e3875f7bea7b6f3 --- /dev/null +++ b/vllm_omni/inputs/preprocess.py @@ -0,0 +1,144 @@ +from typing import Any + +from typing_extensions import assert_never +from vllm.inputs.data import SingletonInputs, SingletonPrompt +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.multimodal.inputs import MultiModalInputs, MultiModalUUIDDict + +from vllm_omni.inputs.data import ( + OmniTextPrompt, + OmniTokenInputs, + OmniTokensPrompt, + token_inputs_omni, +) +from vllm_omni.inputs.parse import parse_singleton_prompt_omni + +logger = init_logger(__name__) + + +class OmniInputPreprocessor(InputPreprocessor): + """Input preprocessor for omni models. + + Extends the base InputPreprocessor to handle omni-specific input + types including prompt embeddings and additional information payloads. + Supports processing tokens, embeddings, text, and multimodal inputs. 
+ """ + + def _process_text( + self, + parsed_content: OmniTextPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> OmniTokenInputs | MultiModalInputs: + prompt_text = parsed_content["prompt"] + + inputs: OmniTokenInputs | MultiModalInputs + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs") or {}, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + prompt_embeds = parsed_content.get("prompt_embeds") + if prompt_embeds is not None: + inputs["prompt_embeds"] = prompt_embeds + additional_information = parsed_content.get("additional_information") + if additional_information is not None: + inputs["additional_information"] = additional_information + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs_omni( + prompt_token_ids, + prompt_embeds=parsed_content.get("prompt_embeds"), + additional_information=parsed_content.get("additional_information"), + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _process_tokens( + self, + parsed_content: OmniTokensPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> OmniTokenInputs | MultiModalInputs: + prompt_token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"], tokenization_kwargs) + prompt_embeds = parsed_content.get("prompt_embeds") + additional_information = parsed_content.get("additional_information") + + inputs: OmniTokenInputs | MultiModalInputs + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs") or {}, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + if prompt_embeds is not None: + inputs["prompt_embeds"] = prompt_embeds + if additional_information is not None: + inputs["additional_information"] = additional_information + else: + inputs = token_inputs_omni( + prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, + additional_information=additional_information, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _prompt_to_llm_inputs( + self, + prompt: SingletonPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> SingletonInputs: + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + * return_mm_hashes: whether to return multimodal hashes + + Returns: + + * Input container compatible with vLLM's singleton prompt handling. 
+ """ + parsed = parse_singleton_prompt_omni(prompt) + + if parsed["type"] == "tokens": + return self._process_tokens( + parsed["content"], + mm_uuids=mm_uuids, + ) + if parsed["type"] == "text": + return self._process_text( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + if parsed["type"] == "str": + return self._process_text( + OmniTextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + assert_never(parsed) diff --git a/vllm_omni/logger.py b/vllm_omni/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..1bcbc027b9328424fdce4853e0a338fc5eb99df4 --- /dev/null +++ b/vllm_omni/logger.py @@ -0,0 +1,22 @@ +import logging + +from vllm.logger import init_logger + + +def _configure_vllm_omni_root_logger(): + """ + Configure the root logger for vllm_omni to propagate to vllm's root logger. + """ + vllm_root = logging.getLogger("vllm") + vllm_omni_root = logging.getLogger("vllm_omni") + vllm_omni_root.handlers = [] + + vllm_omni_root.parent = vllm_root + + vllm_omni_root.propagate = True + + vllm_omni_root.setLevel(logging.NOTSET) + + +_configure_vllm_omni_root_logger() +init_logger(__name__) diff --git a/vllm_omni/lora/__init__.py b/vllm_omni/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/lora/request.py b/vllm_omni/lora/request.py new file mode 100644 index 0000000000000000000000000000000000000000..55eb02ba4473764006dbbcbbdaf8b032ee349589 --- /dev/null +++ b/vllm_omni/lora/request.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# for now, it suffices to use vLLM's implementation directly +# as this is a user-facing variable, defined here to so that user can directly import LoRARequest from vllm_omni +from vllm.lora.request import LoRARequest + +__all__ = ["LoRARequest"] diff --git a/vllm_omni/lora/utils.py b/vllm_omni/lora/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9404d080f6ca2525391f8d80631bd256e3154718 --- /dev/null +++ b/vllm_omni/lora/utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import hashlib + + +def stable_lora_int_id(lora_path: str) -> int: + """Return a deterministic positive integer ID for a LoRA adapter. + + vLLM uses `lora_int_id` as the adapter's cache key. Python's built-in + `hash()` is intentionally randomized per process (PYTHONHASHSEED), which + makes it unsuitable for persistent IDs. This helper derives a stable + 63-bit positive integer from the adapter path. 
+ """ + digest = hashlib.sha256(lora_path.encode("utf-8")).digest() + value = int.from_bytes(digest[:8], byteorder="big", signed=False) & ((1 << 63) - 1) + return value or 1 + + +__all__ = ["stable_lora_int_id"] diff --git a/vllm_omni/model_executor/__init__.py b/vllm_omni/model_executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/custom_process_mixin.py b/vllm_omni/model_executor/custom_process_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2942086a9f9a911f371f792c5c067304f85548 --- /dev/null +++ b/vllm_omni/model_executor/custom_process_mixin.py @@ -0,0 +1,44 @@ +from collections.abc import Callable + +import torch + + +class CustomProcessMixin: + """ + Mixin class for all stages in the Omni model. + """ + + def set_custom_preprocess(self, preprocess_fn: Callable) -> None: + """ + Set a preprocess function for the stage. + Args: + preprocess_fn: The preprocess function to register. + """ + self.preprocess = preprocess_fn + + def set_custom_postprocess(self, postprocess_fn: Callable) -> None: + """ + Set a postprocess function for the stage. + Args: + postprocess_fn: The postprocess function to register. + """ + self.postprocess = postprocess_fn + + def preprocess( + self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **input_dict: object + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + """ + Process the input_ids and input_embeds for the given input_dict. + Returns the processed input_ids, input_embeds, and the input_dict. + If the stage don't applicable, return the original input_ids, input_embeds, and an empty dict. + """ + raise NotImplementedError("Preprocess is not implemented for this stage.") + + def postprocess(self, model_output, **info_dict: object): + """ + Postprocess the model output. + Returns the postprocessed model output and the save dictionary. + Args: + model_output: The model output to postprocess. + """ + raise NotImplementedError("Postprocess is not implemented for this stage.") diff --git a/vllm_omni/model_executor/layers/__init__.py b/vllm_omni/model_executor/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/layers/rotary_embedding/__init__.py b/vllm_omni/model_executor/layers/rotary_embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02ebeb194034f4f6aa24a9c8f1ac668f57aef50a --- /dev/null +++ b/vllm_omni/model_executor/layers/rotary_embedding/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .mrope import OmniMRotaryEmbedding + +__all__ = ["OmniMRotaryEmbedding"] diff --git a/vllm_omni/model_executor/layers/rotary_embedding/mrope.py b/vllm_omni/model_executor/layers/rotary_embedding/mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..463e555073c2ea479ac79ace53847fc0d7bf755c --- /dev/null +++ b/vllm_omni/model_executor/layers/rotary_embedding/mrope.py @@ -0,0 +1,554 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Omni-extended MRotaryEmbedding with multimodal position computation methods. 
+ +This module extends the upstream vLLM MRotaryEmbedding with additional methods +for computing input positions for various multimodal scenarios including: +- Image/Video inputs (Qwen2.5-VL style) +- Audio inputs (Qwen2.5-Omni style) +- Audio-in-video interleaved inputs +- GLM4V style inputs +""" + +import itertools + +import torch +from transformers import PretrainedConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding + +logger = init_logger(__name__) + + +class OmniMRotaryEmbedding(MRotaryEmbedding): + """Omni-extended MRotaryEmbedding with multimodal position computation. + + Extends the upstream MRotaryEmbedding with additional class methods for + computing input positions for various multimodal scenarios. + """ + + @classmethod + def get_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[list[list[int]], int]: + """Get mrope input positions and delta value.""" + + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts + + llm_positions, mrope_position_delta = cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + return llm_positions.tolist(), mrope_position_delta + + @classmethod + def get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + elif hf_config.model_type in ["glm4v", "glm4v_moe"]: + return cls._glm4v_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _glm4v_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | 
torch.Tensor, + context_len: int = 0, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby(enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + + @classmethod + def _vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + 
second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere(input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = ( + ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope 
input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + thinker_config = hf_config.thinker_config + try: + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + except Exception: + logger.info("Multimodal token idx changed!") + audio_token_id = thinker_config.audio_token_id + image_token_id = thinker_config.image_token_id + video_token_id = thinker_config.video_token_id + + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = 
video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) * [audio_token_id]) + audio_start_idx = ( + start_idx if len(audio_llm_pos_ids_list) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend((pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + llm_pos_ids_list[-1].max() + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index_tensor = ( + 
torch.Tensor(t_index).to(llm_grid_h.device).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + ) + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: list[int] | torch.Tensor, + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates diff --git a/vllm_omni/model_executor/model_loader/__init__.py b/vllm_omni/model_executor/model_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/model_loader/weight_utils.py b/vllm_omni/model_executor/model_loader/weight_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7432ad9a2a4676c8829975345650ab78e8dd302a --- /dev/null +++ b/vllm_omni/model_executor/model_loader/weight_utils.py @@ -0,0 +1,72 @@ +import time +from pathlib import Path + +import huggingface_hub +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import DisabledTqdm, get_lock + +if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download +else: + from huggingface_hub import snapshot_download + +logger = init_logger(__name__) + + +def download_weights_from_hf_specific( + model_name_or_path: str, + cache_dir: str | None, + allow_patterns: list[str], + revision: str | None = None, + 
ignore_patterns: str | list[str] | None = None, +) -> str: + """Download model weights from Hugging Face Hub. Users can specify the + allow_patterns to download only the necessary weights. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (list[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, list[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. + """ + assert len(allow_patterns) > 0 + local_only = huggingface_hub.constants.HF_HUB_OFFLINE + download_kwargs = {"tqdm_class": DisabledTqdm} if not envs.VLLM_USE_MODELSCOPE else {} + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + start_time = time.perf_counter() + for allow_pattern in allow_patterns: + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_pattern, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_only, + **download_kwargs, + ) + # If we have downloaded weights for this allow_pattern, + # we don't need to check the rest. + if any(Path(hf_folder).glob(allow_pattern)): + break + time_taken = time.perf_counter() - start_time + if time_taken > 0.5: + logger.info( + "Time spent downloading weights for %s: %.6f seconds", + model_name_or_path, + time_taken, + ) + return hf_folder diff --git a/vllm_omni/model_executor/models/__init__.py b/vllm_omni/model_executor/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2629b4a55663b9fd7a953d1442e22e5e93b7f5 --- /dev/null +++ b/vllm_omni/model_executor/models/__init__.py @@ -0,0 +1,4 @@ +from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration +from .registry import OmniModelRegistry # noqa: F401 + +__all__ = ["Qwen3OmniMoeForConditionalGeneration"] diff --git a/vllm_omni/model_executor/models/output_templates.py b/vllm_omni/model_executor/models/output_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed20980651cf7a758475d98093f25e88df4ba56 --- /dev/null +++ b/vllm_omni/model_executor/models/output_templates.py @@ -0,0 +1,13 @@ +from typing import NamedTuple + +import torch +from vllm.sequence import IntermediateTensors + + +class OmniOutput(NamedTuple): + """Output from the merged Omni model containing both text and audio.""" + + text_hidden_states: torch.Tensor + multimodal_outputs: dict | None = None + intermediate_tensors: IntermediateTensors | None = None + next_token_id: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py b/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/audio_length.py b/vllm_omni/model_executor/models/qwen2_5_omni/audio_length.py new file mode 100644 index 0000000000000000000000000000000000000000..4162f88e57f98f06daa7beaebb6b39393258c5a5 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/audio_length.py 
@@ -0,0 +1,69 @@ +"""Audio length helpers (pure-Python). + +These utilities are used to keep code2wav (codec -> mel -> wav) lengths aligned +when applying a mel-frame cap. +""" + +from __future__ import annotations + + +def resolve_max_mel_frames(max_mel_frames: int | None, *, default: int = 30000) -> int: + """Resolve max mel frames from an explicit value or default. + + Args: + max_mel_frames: Explicit value to use. If None, uses `default`. + default: Default value to use when `max_mel_frames` is None. + + Returns: + The resolved max mel frames value. + """ + if max_mel_frames is not None: + return int(max_mel_frames) + return int(default) + + +def cap_and_align_mel_length( + *, + code_len: int, + repeats: int, + max_mel_frames: int | None, + default_max_mel_frames: int = 30000, +) -> tuple[int, int]: + """Compute a (target_code_len, target_mel_len) pair. + + - `mel_len` is always a multiple of `repeats` (codec expansion factor). + - If `max_mel_frames` is None, uses `default_max_mel_frames`. + - If `max_mel_frames` <= 0, no cap is applied (mel_len == code_len * repeats). + - If `max_mel_frames` is smaller than `repeats` and code_len > 0, we still + return at least one codec token worth of mel frames (mel_len == repeats) + so downstream repeat-interleave stays valid. + """ + code_len = int(code_len) + repeats = int(repeats) + if repeats <= 0: + raise ValueError(f"repeats must be > 0, got {repeats}") + if code_len <= 0: + return 0, 0 + + if max_mel_frames is None: + max_mel_frames = int(default_max_mel_frames) + else: + max_mel_frames = int(max_mel_frames) + + maximum_duration = int(code_len * repeats) + if max_mel_frames > 0: + target_duration = min(maximum_duration, max_mel_frames) + else: + target_duration = maximum_duration + + # Align down to repeats; then ensure we keep at least one codec token. 
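+    # Illustrative example (values assumed, not taken from any config):
+    # code_len=100, repeats=2, max_mel_frames=151 -> maximum_duration=200,
+    # capped to 151, aligned down to 150, so the result is
+    # (target_code_len=75, target_mel_len=150).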
+ target_duration = (target_duration // repeats) * repeats + if target_duration <= 0: + target_duration = min(maximum_duration, repeats) + + target_code_len = target_duration // repeats + if target_code_len <= 0: + target_code_len = 1 + target_duration = repeats + + return int(target_code_len), int(target_duration) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..936a18427cc08e1899fce606c0e01cab7d256369 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -0,0 +1,1043 @@ +import glob +import os +from collections.abc import Iterable +from functools import cached_property + +import numpy as np +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, + Qwen2_5OmniTalkerConfig, + Qwen2_5OmniThinkerConfig, +) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) +from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix +from vllm.model_executor.models.vision import ( + get_llm_pos_ids_for_vision, +) + +# from vllm.model_executor.models.qwen2_code2wav_dit import Qwen2Code2wav +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFeatureSpec +from vllm.sequence import IntermediateTensors +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, split_list_into_ranges +from vllm_omni.platforms import current_omni_platform + +TALKER_CODEC_EOS_TOKEN_ID = 8294 +TALKER_CODEC_BOS_TOKEN_ID = 8293 + + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, CustomProcessMixin +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.has_preprocess = False + self.have_multimodal_outputs = True + config: Qwen2_5OmniConfig = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + # keep vllm_config for later submodule init + self.vllm_config = vllm_config + + # Initialize thinker components + thinker_config: Qwen2_5OmniThinkerConfig = config.thinker_config + self.thinker_config = thinker_config + self.multimodal_config = multimodal_config + + # Initialize talker components + talker_config: Qwen2_5OmniTalkerConfig = config.talker_config + self.talker_config = talker_config + + self.model_stage = 
vllm_config.model_config.model_stage + if self.model_stage == "thinker": + # Initialize thinker model (multimodal processing) + self.thinker = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + hf_config=thinker_config, + # Use registry architecture key + architectures=["Qwen2_5OmniThinkerModel"], + ) + self.model = self.thinker + self.talker = None + self.token2wav = None + + elif self.model_stage == "talker": + # register the process function for the talker stage + self.has_preprocess = True + self.set_custom_preprocess(self.talker_preprocess) + self.thinker = None + # Initialize talker model wrapper (handles projection + LM) + self.talker = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "talker"), + hf_config=talker_config, + # Use registry architecture key + architectures=["Qwen2_5OmniTalkerModel"], + ) + self.talker.init_multi_modal(thinker_config) + self.model = self.talker + self.token2wav = None + # set suppress start id according to token2wav + t2w_token_end_id = getattr( + getattr(getattr(config, "token2wav_config", None), "dit_config", None), "num_embeds", None + ) + if t2w_token_end_id: + self.model.set_suppress_start_id(t2w_token_end_id + 1) + self.requires_raw_input_tokens = True + + elif self.model_stage == "code2wav": + self.thinker = None + self.talker = None + # Initialize token2wav (code->mel->wav) like thinker/talker + self.token2wav_config = getattr(config, "token2wav_config", None) + self.token2wav = None + if self.token2wav_config is not None: + self.token2wav = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "token2wav"), + hf_config=self.token2wav_config, + architectures=["Qwen2_5OmniToken2WavModel"], + ) + # voice resources (loaded on demand) + self._token2wav_conds: dict[str, torch.Tensor] = {} + self._token2wav_ref_mels: dict[str, torch.Tensor] = {} + self.model = self.token2wav + self.requires_raw_input_tokens = True + else: + raise ValueError("Invalid model stage") + + # Set up intermediate tensors + self.make_empty_intermediate_tensors = ( + (self.thinker.make_empty_intermediate_tensors) if self.model_stage == "thinker" else lambda: None + ) + + # -------------------- Device utilities -------------------- + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + # No parameters; fall back to buffers or cpu + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + + def move_submodules_to_devices( + self, + *, + thinker_device: str | torch.device | None = None, + talker_device: str | torch.device | None = None, + token2wav_device: str | torch.device | None = None, + ) -> None: + """Optionally move thinker/talker/token2wav to different devices. 
+ + Example: + model.move_submodules_to_devices( + thinker_device='cuda:0', + talker_device='cuda:1', + token2wav_device='cpu', + ) + """ + if thinker_device is not None and self.thinker is not None: + self.thinker.to(thinker_device) + if talker_device is not None and self.talker is not None: + self.talker.to(talker_device) + if token2wav_device is not None and self.token2wav is not None: + self.token2wav.to(token2wav_device) + + @cached_property + def sampler(self): + if hasattr(self.model, "sampler"): + return self.model.sampler + return Sampler() + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + is_multimodal=None, + ) -> torch.Tensor: + if self.model_stage == "code2wav": + return torch.zeros_like(input_ids).reshape(-1, 1).repeat(1, self.vllm_config.model_config.get_hidden_size()) + return self.model.embed_input_ids( + input_ids=input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal + ) + + def embed_multimodal(self, **kwargs): + # Delegate to thinker model for multimodal processing + return self.model.embed_multimodal(**kwargs) + + def last_index_of(self, list, value): + return len(list) - 1 - list[::-1].index(value) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + generate_audio: bool = True, + voice_type: str = "Chelsie", + codec: torch.Tensor | None = None, + sampling_metadata: SamplingMetadata | None = None, + logits_index: int | None = None, + sampler=None, + additional_information: dict[str, object] | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors | OmniOutput: + """ + Workflow: + 1) Thinker: multimodal understanding → text hidden states. + 2) If audio requested and codec not provided, use talker to derive codec. + 3) If audio requested (or codec provided), use token2wav to synthesize waveform. + 4) Return text hidden states (and audio when applicable). 
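+
+        Only the submodule selected by `model_config.model_stage` is built in
+        `__init__`, so a given instance runs exactly one of the steps above and
+        returns its own `OmniOutput`.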
+ """ + if self.model_stage == "thinker": + # Normalize to batched inputs if caller provides 1D/2D unbatched tensors + # TODO: Remove this hack when NPU supports batched inputs properly + added_batch_dim = False + if input_ids is not None and input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + added_batch_dim = True + if positions is not None and positions.ndim == 1: + positions = positions.unsqueeze(0) + added_batch_dim = True + if inputs_embeds is not None and inputs_embeds.ndim == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + added_batch_dim = True + thinker_dev = self._module_device(self.thinker) + + # if input_ids is None, set it to a zero tensor, in the length of the + # same as the embedding seq length + if input_ids is None: + input_ids = torch.zeros(inputs_embeds.shape[1], dtype=torch.long, device=thinker_dev).unsqueeze( + 0 + ) # (1, 0) + added_batch_dim = True + + # 1) Thinker (ensure inputs on thinker's device) + if input_ids is not None and input_ids.device != thinker_dev: + input_ids = input_ids.to(thinker_dev) + if positions is not None and positions.device != thinker_dev: + positions = positions.to(thinker_dev) + if inputs_embeds is not None and inputs_embeds.device != thinker_dev: + inputs_embeds = inputs_embeds.to(thinker_dev) + + if current_omni_platform.is_npu(): + # TODO: remove this hack when NPU supports batched inputs properly + thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + thinker_positions = positions # MRoPE positions, keep as is + else: + thinker_positions = positions[0] if positions.ndim > 1 else positions + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + ) + else: + # Squeeze back if we added batch dim earlier + thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + thinker_positions = positions # MRoPE positions, keep as is + elif added_batch_dim: + thinker_positions = positions[0] + else: + thinker_positions = positions + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + ) + + # Run thinker + thinker_output = self.thinker( + input_ids=thinker_input_ids, + positions=thinker_positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=thinker_inputs_embeds, + **kwargs, + ) + + if isinstance(thinker_output, tuple): + embeds, text_hidden_states = thinker_output + else: + text_hidden_states = thinker_output + + # Text-only path + return OmniOutput( + text_hidden_states=(text_hidden_states.reshape(-1, text_hidden_states.shape[-1])), + multimodal_outputs=None, + ) + + # 2) Talker (if codec not provided) + if self.model_stage == "talker": + # mock data for profile + if input_ids is None: + input_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device) + self.thinker_reply_part = torch.zeros_like(inputs_embeds) + + # TODO(Peiqi): temporal hack here to support voice_type. 
+ if not hasattr(self, "voice_type"): + self.voice_type = voice_type + + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + talker_positions = positions # MRoPE positions, keep as is + else: + talker_positions = positions[0] + + with torch.inference_mode(): + talker_hidden = self.talker( + input_ids=input_ids, + positions=talker_positions, + inputs_embeds=inputs_embeds, + ) + + if sampling_metadata is not None: + # the padding token id is set to text model's pad token id, + # which do not match with the talker model's word embedding size + sampling_metadata.prompt_token_ids[sampling_metadata.prompt_token_ids == 152064] = 8448 + + return OmniOutput( + text_hidden_states=talker_hidden, + multimodal_outputs=None, + ) + + if self.model_stage == "code2wav": + code = ( + input_ids + if input_ids is not None + else torch.zeros( + inputs_embeds.shape[0], + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + + code = code[:-1] if code[-1] == TALKER_CODEC_EOS_TOKEN_ID else code + code = code[1:] if code[0] == TALKER_CODEC_BOS_TOKEN_ID else code + + audio_tensor = self.generate_audio(code, voice_type) + return OmniOutput(text_hidden_states=None, multimodal_outputs={"model_outputs": audio_tensor}) + + return OmniOutput( + text_hidden_states=torch.cat( + [ + torch.zeros( + [inputs_embeds.shape[0], self.talker.config.hidden_size], + dtype=torch.bfloat16, + ).to(self._module_device(self.model)), + self.talker.thinker_to_talker_proj( + self.talker.embed_input_ids( + torch.tensor([TALKER_CODEC_BOS_TOKEN_ID, TALKER_CODEC_EOS_TOKEN_ID]) + .to(torch.bfloat16) + .to(self._module_device(self.model)) + ) + )[0], + ], + dim=0, + ), + multimodal_outputs=None, + ) + + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec] | None = None, + *, + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
+ + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = 
video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) * [audio_token_id]) + audio_start_idx = ( + start_idx if len(audio_llm_pos_ids_list) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend((pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + llm_pos_ids_list[-1].max() + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + def generate_audio(self, code, voice_type): + token2wav_dev = self._module_device(self.token2wav) + if isinstance(code, torch.Tensor): + code_tensor = code.to(dtype=torch.long, device=token2wav_dev) + else: + code_tensor = torch.as_tensor(code, dtype=torch.long, device=token2wav_dev) + if code_tensor.ndim == 2 and code_tensor.shape[0] == 1: + code_tensor = code_tensor.squeeze(0) + + audio_tensor = self._codec_to_audio(code_tensor, voice_type) + + return audio_tensor + + def _load_talker_embedding( + self, + ) -> torch.nn.Embedding: + return self.talker.language_model.model.embed_tokens + + def _init_special_tokens_embeddings( + self, + ): + # talker embeddings + self.talker_embedding = self._load_talker_embedding() + + # embed_text_bos_token + self.tts_text_spk_token_ids = { + # M02: Male voice with standard Mandarin and a slight northern accent + "m02": 151870, + "Ethan": 151870, + # F030: Your anime-styled virtual girlfriend + "f030": 151872, + "Chelsie": 151872, + } + self.default_tts_text_spk_type = list(self.tts_text_spk_token_ids.keys())[0] + self.tts_text_spk_token_ids["prefix_caching"] = 151870 + + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, "talker_config"): + talker_hf_config = talker_hf_config.talker_config + + self.embed_text_bos_token = self.thinker_embedding( + torch.tensor( + 
[talker_hf_config.tts_text_start_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + self.embed_text_spk_tokens = { + key: self.thinker_embedding( + torch.tensor( + [value], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + for key, value in self.tts_text_spk_token_ids.items() + } + self.embed_text_eos_token = self.thinker_embedding( + torch.tensor( + [talker_hf_config.tts_text_end_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + self.embed_text_pad_token = self.thinker_embedding( + torch.tensor( + [talker_hf_config.tts_text_pad_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + self.embed_codec_bos_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_start_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + self.embed_codec_eos_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_end_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + self.embed_codec_pad_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_pad_token_id], + dtype=torch.long, + device=self._module_device(self.talker), + ) + ) + return set(["thinker_embedding.weight", "talker_embedding.weight"]) + + def _get_embed_text_spk_token(self, voice_type: str): + if voice_type not in self.embed_text_spk_tokens: + return self.embed_text_bos_token + return self.embed_text_spk_tokens[voice_type] + + def _get_text_spk_token_id(self, voice_type: str): + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, "talker_config"): + talker_hf_config = talker_hf_config.talker_config + + if voice_type not in self.tts_text_spk_token_ids: + return talker_hf_config.tts_text_start_token_id + return self.tts_text_spk_token_ids[voice_type] + + def talker_preprocess( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + **info_dict: object, + ): + # Mixed-mode support: In a single step, both Prefill*n and Decode*n are supported. + # Rules: + # - Prefill segments are wrapped with special tokens: [BOS][PAD...][EOS] + # - Decode segments consist of a single non-special token. + # - If additional_information is provided (can be a list split by request or a + # concatenated tensor plus a list of shapes), then for each request, reconstruct + # the thinker→talker input embeddings for the Prefill segments; + # - For Decode segments, if per-request auxiliary decode embeddings are provided (optional), + # add them; otherwise, keep the original embedding. 
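+        # For reference, the prefill path below consumes the info_dict entries
+        # "prompt_embeds", "thinker_result", "prompt_token_ids" and
+        # "thinker_output_token_ids"; the decode path reads "thinker_reply_part"
+        # first and falls back to "decode_output_prompt_embeds".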
+ + # Ensure we have base embeddings when only ids are provided + if input_embeds is None and input_ids is not None: + input_embeds = self.talker.embed_input_ids(input_ids) + + span_len = input_ids.shape[0] + if span_len > 1: + # prefill + return self.thinker_to_talker_process(input_ids, input_embeds, **info_dict) + else: + # decode + return self.thinker_to_talker_decode_one_step(input_ids, input_embeds, **info_dict) + + def thinker_to_talker_process( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + **info_dict: object, + ): + update_dict = {} + + prompt_embeds = info_dict.get("prompt_embeds") # Tensor [P,H] + thinker_result = info_dict.get("thinker_result") # Tensor [K,H] + prompt_token_ids = info_dict.get("prompt_token_ids") # list[int] + thinker_output_token_ids = info_dict.get("thinker_output_token_ids") # list[int] + + if not isinstance(prompt_embeds, torch.Tensor): + prompt_embeds = torch.zeros( + 0, self.talker.config.hidden_size, dtype=input_embeds.dtype, device=self._module_device(self.model) + ) + if not isinstance(thinker_result, torch.Tensor): + thinker_result = torch.zeros( + 0, self.talker.config.hidden_size, dtype=input_embeds.dtype, device=self._module_device(self.model) + ) + if not isinstance(prompt_token_ids, (list, torch.Tensor)): + prompt_token_ids = [] + if not isinstance(thinker_output_token_ids, (list, torch.Tensor)): + thinker_output_token_ids = [] + + # TODO(Peiqi): add voice_type support + req_input_ids, req_embeds = self._thinker_to_talker_prefill( + voice_type=self.voice_type, + output_prompt_embeds=thinker_result.to(input_embeds.dtype).to(self._module_device(self.model)), + output_token_ids=thinker_output_token_ids, + thinker_prompt_embeds=prompt_embeds.to(input_embeds.dtype).to(self._module_device(self.model)), + prompt_token_ids=prompt_token_ids, + ) + + if thinker_result.ndim == 2 and thinker_result.shape[0] > 0: + update_dict["thinker_reply_part"] = thinker_result[1:].detach().to("cpu").contiguous() + + return req_input_ids, req_embeds, update_dict + + def _thinker_to_talker_prefill( + self, + voice_type: str, + output_prompt_embeds, + output_token_ids, + thinker_prompt_embeds, + prompt_token_ids, + ): + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, "talker_config"): + talker_hf_config = talker_hf_config.talker_config + + # if len(output.outputs[0].token_ids) == 2: + # issue request + prompt_embeds = torch.cat( + [ + thinker_prompt_embeds, + self._get_embed_text_spk_token(voice_type) + self.embed_codec_pad_token, + output_prompt_embeds[:1] + self.embed_codec_bos_token, + ], + dim=0, + ) + + prompt_token_ids_processed = prompt_token_ids + [ + talker_hf_config.tts_codec_pad_token_id, + output_token_ids[0], + ] + input_tokens_len = len(prompt_token_ids_processed) + # the code below is from model runner in Qwen, may need to further discuss later + if input_tokens_len > 2: + prompt_token_ids_processed = [self.talker_config.tts_codec_mask_token_id] * (input_tokens_len - 2) + [ + self.talker_config.tts_codec_pad_token_id, + self.talker_config.tts_codec_start_token_id, + ] + else: + prompt_token_ids_processed = [ + self.talker_config.tts_codec_pad_token_id, + self.talker_config.tts_codec_start_token_id, + ][-input_tokens_len:] + if isinstance(prompt_token_ids_processed, list): + prompt_token_ids_processed = ( + torch.Tensor(prompt_token_ids_processed).to(torch.int64).to(self._module_device(self.talker)) + ) + return prompt_token_ids_processed, prompt_embeds + + def thinker_to_talker_decode_one_step(self, input_ids, 
input_embeds, **info_dict): + update_dict = {} + # choose step vector in priority order + step_vec = None + q = info_dict.get("thinker_reply_part", None) + if isinstance(q, torch.Tensor) and q.numel() > 0: + step_vec = q[0:1] + new_q = q[1:].detach().to("cpu").contiguous() + update_dict["thinker_reply_part"] = new_q + else: + # B) per-request provided decode vector (optional) + dv = info_dict.get("decode_output_prompt_embeds") if isinstance(info_dict, dict) else None + if isinstance(dv, torch.Tensor) and dv.numel() > 0: + step_vec = dv[0:1] if dv.ndim == 2 else dv.view(1, -1) + elif ( + hasattr(self, "thinker_reply_part") + and isinstance(self.thinker_reply_part, torch.Tensor) + and self.thinker_reply_part.numel() > 0 + ): + # C) fallback shared pool + step_vec = self.thinker_reply_part[0:1] + self.thinker_reply_part = self.thinker_reply_part[1:] + + if isinstance(step_vec, torch.Tensor) and step_vec.numel() > 0: + one_id = input_ids[0:1] + _, one_embed = self._thinker_to_talker_decode_one_step( + output_prompt_embeds=step_vec.to(input_embeds.dtype).to(self._module_device(self.model)), + output_token_ids=one_id, + ) + input_embeds[0] = one_embed[0] + return input_ids[0:1], input_embeds[0:1], update_dict + + def _thinker_to_talker_decode_one_step( + self, + output_prompt_embeds, + output_token_ids, + ): + processed_output_token_embeds = output_prompt_embeds + self.talker.embed_input_ids( + output_token_ids + ) # for decode + return output_token_ids, processed_output_token_embeds + + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, **kwargs: object) -> torch.Tensor | None: + # Handle OmniOutput type + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + + # Use thinker model for logits computation + return self.model.compute_logits(hidden_states) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + # Use thinker model for sampling + return self.model.sample(logits, sampling_metadata) + + def generate_speech(self, text_tokens: torch.Tensor, voice_type: str = "default") -> torch.Tensor: + """ + Generate speech from text tokens using the talker and token2wav models. + This method is kept for backward compatibility and direct speech generation. + + Args: + text_tokens: Text tokens from thinker model + voice_type: Voice type for speech generation + + Returns: + Audio tensor + """ + # Generate codec tokens using talker model + talker_output = self.talker(input_ids=None, positions=None, inputs_embeds=text_tokens) + + # Convert talker output to codec tokens + codec_tokens = self._convert_to_codec_tokens(talker_output) + + # Generate audio using token2wav model + return self._codec_to_audio(codec_tokens, voice_type=voice_type) + + def _convert_to_codec_tokens( + self, talker_output: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + """ + Reference (HF): use the talker's codec head to obtain logits, suppress BOS, + then greedily select the next codec token for the current step. 
+ """ + with torch.inference_mode(): + logits = self.talker.compute_logits(talker_output, None) + if logits is None: + return torch.zeros( + (talker_output.size(0), 0), + dtype=torch.long, + device=talker_output.device, + ) + + # Suppress only codec_bos, consistent with HF generate's + # suppress_tokens behavior + bos_id = None + if hasattr(self, "talker_config") and hasattr(self.talker_config, "tts_codec_start_token_id"): + bos_id = int(getattr(self.talker_config, "tts_codec_start_token_id")) + if bos_id is not None: + logits[..., bos_id] = -1e9 + + # Take the distribution at the last step and select greedily + next_id = self.talker.sample(logits, sampling_metadata).sampled_token_ids + return next_id.to(dtype=torch.long) + + def _init_token2wav_model(self, hf_model_folder): + """Initialize speaker resources if provided; model is constructed in + __init__.""" + if self.token2wav is None or self.token2wav_config is None: + return + device = self._module_device(self.token2wav) + # optional speaker resources + conds = getattr(self.token2wav_config, "conds", None) + ref_mels = getattr(self.token2wav_config, "ref_mels", None) + if isinstance(conds, dict) and isinstance(ref_mels, dict): + self._token2wav_conds = {k: torch.as_tensor(v, device=device) for k, v in conds.items()} + self._token2wav_ref_mels = {k: torch.as_tensor(v, device=device) for k, v in ref_mels.items()} + # legacy: load from directory if provided + model_path = hf_model_folder + if isinstance(model_path, str) and os.path.isdir(model_path): + spk_pt = os.path.join(model_path, "spk_dict.pt") + if os.path.exists(spk_pt): + data = torch.load(spk_pt, map_location=device) + for key, value in data.items(): + self._token2wav_conds[key] = value["cond"].to(device) + self._token2wav_ref_mels[key] = value["ref_mel"].to(device) + else: + # legacy npy inputs + for f in sorted(glob.glob(os.path.join(model_path, "inputs", "*spk_emb.npy"))): + key = os.path.basename(f).split("_")[0].lower() + self._token2wav_conds[key] = torch.as_tensor(np.load(f), device=device) + for f in sorted(glob.glob(os.path.join(model_path, "inputs", "*ref_mel.npy"))): + key = os.path.basename(f).split("_")[0].lower() + self._token2wav_ref_mels[key] = torch.as_tensor(np.load(f), device=device) + + def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default") -> torch.Tensor | None: + if self.token2wav is None: + self._init_token2wav_model() + if self.token2wav is None: + return None + # Normalize voice type + voice = voice_type or "default" + # Resolve cond / ref_mel if provided + cond = None + ref_mel = None + if voice in self._token2wav_conds and voice in self._token2wav_ref_mels: + cond = self._token2wav_conds[voice] + ref_mel = self._token2wav_ref_mels[voice] + # Fallback: create dummy cond/ref_mel if not provided + token2wav_dev = self._module_device(self.token2wav) + if cond is None: + cond = torch.zeros( + (1, self.token2wav_config.dit_config.enc_emb_dim), + device=token2wav_dev, + dtype=torch.float32, + ) + if ref_mel is None: + ref_mel = torch.zeros( + (1, 300, self.token2wav_config.dit_config.mel_dim), + device=token2wav_dev, + dtype=torch.float32, + ) + + # Ensure codec is (1, T) long tensor on correct device + if isinstance(codec_tokens, torch.Tensor): + codec = codec_tokens.to(dtype=torch.long, device=token2wav_dev) + if codec.ndim == 1: + codec = codec.unsqueeze(0) + else: + codec = torch.as_tensor(codec_tokens, dtype=torch.long, device=token2wav_dev).unsqueeze(0) + + # Streaming with chunked process and boundary alignment + # (rely on 
token2wav.process_chunk) + factor = getattr(self.token2wav.token2wav.factor, "factor", 2) + chunk_size = 48 + mel_dim = getattr( + self.token2wav.token2wav.code2wav_dit_model, + "mel_dim", + self.token2wav_config.dit_config.mel_dim, + ) + total_mel = int(codec.shape[1] * factor) + steps = 10 + + # Prepare initial noise for the whole sequence + y_all = torch.randn((1, total_mel, mel_dim), dtype=ref_mel.dtype, device=token2wav_dev) + + logger.info( + "Currently, we do not use the chunked process, we only use the " + "token2wav.process_chunk for the whole sequence. " + "The stream mode will be implemented in the future." + ) + + chunk_ends = [] + for i in range(codec.shape[1]): + chunk_code_length = i * 2 - 24 + finished = i == (codec.shape[1] - 1) + if (chunk_code_length > 0 and chunk_code_length % chunk_size == 0) or finished: + chunk_ends.append(i) + + # Number of chunks in mel domain + prev_generated = None + wav_chunks: list = [] + + with torch.inference_mode(): + for n, i in enumerate([0]): + finished = i == codec.shape[1] - 1 + _, audio_chunk = self.token2wav.process_chunk( + conditioning=cond, + reference_mel=ref_mel, + codec_all=codec, + y_all=y_all, + i=n, + steps=steps, + prev_generated=prev_generated if prev_generated is not None else [], + finished=True, + ) + prev_generated = audio_chunk + wav_chunks.append(audio_chunk.detach().cpu().numpy()) + + if len(wav_chunks) == 0: + return torch.zeros(0, device=token2wav_dev) + + waveform = np.concatenate(wav_chunks) + return torch.as_tensor(waveform, device=token2wav_dev) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for all components of the omni model.""" + loaded_weights = set() + thinker_weights = [] + talker_weights = [] + token2wav_weights = [] + for k, v in weights: + if k.startswith("thinker."): + thinker_weights.append((k, v)) + elif k.startswith("talker."): + talker_weights.append((k, v)) + elif k.startswith("token2wav."): + token2wav_weights.append((k, v)) + else: + raise ValueError(f"Unknown weight prefix: {k}") + + # Load thinker weights + if self.thinker: + if thinker_weights: + thinker_loaded = self.thinker.load_weights(thinker_weights) + else: + thinker_loaded = set([k for k, v in thinker_weights]) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") + loaded_weights.update(thinker_loaded) + + # Load talker weights + if talker_weights and self.talker is not None: + # Map talker weights to appropriate components + if self.thinker is None: + thinker_embedding_weights = [w for n, w in thinker_weights if n == "thinker.model.embed_tokens.weight"] + if thinker_embedding_weights: + self.thinker_embedding = nn.Embedding( + thinker_embedding_weights[0].shape[0], + thinker_embedding_weights[0].shape[1], + ) + self.thinker_embedding.weight = nn.Parameter( + thinker_embedding_weights[0].to(self._module_device(self.talker)) + ) + talker_loaded = self.talker.load_weights(talker_weights) + talker_loaded = add_prefix_to_loaded_weights(talker_loaded, "talker") + loaded_weights.update(talker_loaded) + loaded_weights.update(self._init_special_tokens_embeddings()) + + # Load token2wav weights (if any) + if token2wav_weights and self.token2wav is not None: + # download weights from huggingface for spk_dict.pt + model_path = self.vllm_config.model_config.model + download_dir = self.vllm_config.load_config.download_dir + if os.path.exists(model_path): + hf_model_folder = model_path + else: + hf_model_folder = download_weights_from_hf_specific( + model_path, + 
download_dir, + allow_patterns=["*.pt"], + ) + self._init_token2wav_model(hf_model_folder) + t2w_loaded = self.token2wav.load_weights(token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt")) + t2w_loaded = add_prefix_to_loaded_weights(t2w_loaded, "token2wav") + loaded_weights.update(t2w_loaded) + + return loaded_weights diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py new file mode 100644 index 0000000000000000000000000000000000000000..927bc552573b656fc3af3fd50fa2c00caeb26986 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -0,0 +1,253 @@ +from collections.abc import Iterable +from functools import cached_property + +import torch +import torch.nn as nn +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniTalkerConfig +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAudioEncoder + +# from vllm.attention import AttentionMetadata # unused import +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionTransformer +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import Qwen2_5OmniConditionalGenerationMixin + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniTalkerForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin +): + logger = init_logger(__name__) + # Align to thinker-style static mapper for clarity + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # text LM head/body in talker + "talker.codec_head.": "language_model.lm_head.", + "talker.model.": "language_model.model.", + # projection weights + "talker.thinker_to_talker_proj.": "thinker_to_talker_proj.", + # fallback root + "talker.": "", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Qwen2_5OmniTalkerConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.vllm_config = vllm_config + self.prefix = prefix + self.quant_config = quant_config + + if hasattr(config, "talker_config"): + self.config = config.talker_config + vllm_config.model_config.hf_text_config = vllm_config.model_config.hf_config.talker_config + else: + self.config = config + + self.thinker_to_talker_proj = nn.Linear( + self.config.embedding_size, + self.config.hidden_size, + ) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=getattr(self.config, "text_config", self.config), + architectures=["Qwen2ForCausalLM_old"], 
+ ) + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + + # suppress start id + self.suppress_start_id = None + + def init_multi_modal(self, thinker_config): + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=self.quant_config, + prefix=maybe_prefix(self.prefix, "visual"), + ) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().embed_input_ids(input_ids) + + return super().embed_input_ids( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def forward( + self, + input_ids: torch.Tensor = None, + positions: torch.Tensor = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + assert input_ids is not None or inputs_embeds is not None, "input_ids or inputs_embeds must be provided" + # forward_context: ForwardContext = get_forward_context() # unused variable + + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + # for profile_run: + inputs_embeds = self.embed_input_ids(input_ids) + + input_ids = None + + # projection + inputs_embeds = self.thinker_to_talker_proj(inputs_embeds) + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + return hidden_states + + def bad_word_processor(self, logits: torch.Tensor) -> torch.Tensor: + # suppress token IDs unsupported by token2wav + if self.suppress_start_id and self.suppress_start_id < logits.size(-1): + # skip the end token id. 
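+            # Example with hypothetical ids: for suppress_start_id=8296 and
+            # tts_codec_end_token_id=8294, ids >= 8296 are masked; for
+            # suppress_start_id=8292 the ranges [8292, 8294) and (8294, vocab)
+            # are masked, so the codec EOS token itself stays reachable.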
+ if hasattr(self.config, "tts_codec_end_token_id"): + end_id = int(getattr(self.config, "tts_codec_end_token_id")) + if self.suppress_start_id == end_id: + logits[..., end_id + 1 : logits.size(-1)] = -1e9 + elif self.suppress_start_id < end_id: + logits[..., self.suppress_start_id : end_id] = -1e9 + logits[..., end_id + 1 : logits.size(-1)] = -1e9 + else: + logits[..., self.suppress_start_id : logits.size(-1)] = -1e9 + else: + raise ValueError("config must have tts_codec_end_token_id attribute") + + if hasattr(self.config, "tts_codec_start_token_id"): + bos_id = int(getattr(self.config, "tts_codec_start_token_id")) + logits[..., bos_id] = -1e9 + return logits + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + logits = self.language_model.compute_logits(hidden_states) + logits = self.bad_word_processor(logits) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["thinker.", "token2wav."], + ) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + # Log load summary + try: + total_bytes = 0 + for name, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + self.logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, + True, + total_bytes / (1024**2), + str(device), + ) + except Exception: + pass + multi_model_weights = set() + for name, param in self.visual.named_parameters(): + multi_model_weights.add("visual." + name) + for name, param in self.audio_tower.named_parameters(): + multi_model_weights.add("audio_tower." + name) + loaded.update(multi_model_weights) + return loaded + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds") and "image" not in mm_input_by_modality: + mm_input_by_modality["image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds") and "video" not in mm_input_by_modality: + mm_input_by_modality["video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features") and "audio" not in mm_input_by_modality: + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
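+        # (dicts keep insertion order in Python 3.7+, so embeddings are emitted
+        # in the same order the modalities appear in kwargs)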
+ for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def set_suppress_start_id(self, start_id: int): + self.suppress_start_id = start_id + self.logger.debug(f"Set suppress start id to {self.suppress_start_id}") diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py new file mode 100644 index 0000000000000000000000000000000000000000..bb2de4f148c600d40ce81031a69daaaad59f8f36 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -0,0 +1,572 @@ +"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.14) with minimal overrides.""" + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniThinkerConfig, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder, +) +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin as Qwen2_5OmniConditionalGenerationMixinBase, +) +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionTransformer, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, + split_list_into_ranges, +) +from vllm.model_executor.models.vision import get_llm_pos_ids_for_vision +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, +) +from vllm.sequence import IntermediateTensors + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None +logger = init_logger(__name__) + + +class Qwen2_5OmniConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixinBase): + def _parse_and_validate_audio_input(self, **kwargs: object) -> Qwen2_5OmniAudioFeatureInputs | None: + input_audio_features = kwargs.pop("input_audio_features", None) + audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) + if input_audio_features is None: + return None + if ( + input_audio_features is not None + and isinstance(input_audio_features, torch.Tensor) + and input_audio_features.ndim == 3 + ): + 
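+            # A 3-D tensor carries an extra leading (batch) dim; merge the
+            # leading dims so downstream code gets a single 2-D feature tensor.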
input_audio_features = input_audio_features.reshape(-1, input_audio_features.shape[-1]) + elif input_audio_features is not None and isinstance(input_audio_features, list): + input_audio_features = torch.cat(input_audio_features, dim=-1) + if ( + audio_feature_lengths is not None + and isinstance(audio_feature_lengths, torch.Tensor) + and audio_feature_lengths.ndim == 2 + ): + audio_feature_lengths = audio_feature_lengths.reshape(-1) + elif audio_feature_lengths is not None and isinstance(audio_feature_lengths, list): + audio_feature_lengths = torch.cat(audio_feature_lengths, dim=-1) + if ( + feature_attention_mask is not None + and isinstance(feature_attention_mask, torch.Tensor) + and feature_attention_mask.ndim == 3 + ): + feature_attention_mask = feature_attention_mask.reshape(-1, feature_attention_mask.shape[-1]) + elif feature_attention_mask is not None and isinstance(feature_attention_mask, list): + for i in range(len(feature_attention_mask)): + feature_attention_mask[i] = feature_attention_mask[i].reshape(-1) + return Qwen2_5OmniAudioFeatureInputs( + type="audio_features", + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask, + ) + + def _parse_and_validate_image_input( + self, + **kwargs: dict[str, Any], + ) -> Qwen2_5_VLImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + if pixel_values is not None and isinstance(pixel_values, torch.Tensor) and pixel_values.ndim == 3: + pixel_values = pixel_values.reshape(-1, pixel_values.shape[-1]) + if image_embeds is not None and isinstance(image_embeds, torch.Tensor) and image_embeds.ndim == 3: + image_embeds = image_embeds.reshape(-1, image_embeds.shape[-1]) + if image_grid_thw is not None and isinstance(image_grid_thw, torch.Tensor) and image_grid_thw.ndim == 3: + image_grid_thw = image_grid_thw.reshape(-1, image_grid_thw.shape[-1]) + if pixel_values is not None: + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: dict[str, Any], + ) -> Qwen2_5_VLVideoInputs | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if ( + pixel_values_videos is not None + and isinstance(pixel_values_videos, torch.Tensor) + and pixel_values_videos.ndim == 3 + ): + pixel_values_videos = pixel_values_videos.reshape(-1, pixel_values_videos.shape[-1]) + if video_grid_thw is not None and isinstance(video_grid_thw, torch.Tensor) and video_grid_thw.ndim == 3: + video_grid_thw = video_grid_thw.reshape(-1, video_grid_thw.shape[-1]) + if video_embeds is not None and isinstance(video_embeds, torch.Tensor) and video_embeds.ndim == 3: + video_embeds = video_embeds.reshape(-1, video_embeds.shape[-1]) + if pixel_values_videos is not None: + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + if 
not isinstance(video_embeds, torch.Tensor): + raise ValueError(f"Incorrect type of video embeddings. Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, + SupportsMRoPE, + Qwen2_5OmniConditionalGenerationMixin, +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + } + ) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "attn.qkv": [ + "attn.q", + "attn.k", + "attn.v", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|vision_start|><|IMAGE|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|VIDEO|><|vision_end|>" + if modality.startswith("audio"): + return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + thinker_config: Qwen2_5OmniThinkerConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. 
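+ # (i.e. when flash-attn is importable, the audio encoder config is switched
+ # to the flash_attention_2 backend so the audio tower reproduces the
+ # reference Transformers results; otherwise the default backend is kept and
+ # the warning below is logged.)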
+ if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part." + ) + + if multimodal_config.get_limit_per_prompt("audio"): + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + else: + self.audio_tower = None + + if multimodal_config.get_limit_per_prompt("image") or multimodal_config.get_limit_per_prompt("video"): + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = None + + self.quant_config = quant_config + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=thinker_config.text_config, + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds") and "image" not in mm_input_by_modality: + mm_input_by_modality["image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds") and "video" not in mm_input_by_modality: + mm_input_by_modality["video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features") and "audio" not in mm_input_by_modality: + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + ) -> tuple[torch.Tensor, int]: + """ + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(image_grid_thw) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(video_grid_thw) + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
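+ # Rough outline of the loop below: a plain text token advances all three
+ # (t, h, w) position axes by one; a vision token block gets 3D positions from
+ # get_llm_pos_ids_for_vision(); an audio token block gets 1D positions
+ # broadcast to the three axes. With use_audio_in_video, the video and audio
+ # positions of one item are interleaved chunk by chunk as in the docstring.
+ # The number of <|AUDIO|> placeholders comes from the raw feature length via
+ # two stride-2 reductions; worked example with a hypothetical length:
+ #     audio_seqlen = 300  ->  place_num = ((300 - 1) // 2 + 1 - 2) // 2 + 1 = 75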
+ + thinker_config = self.config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index 
= (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) * [audio_token_id]) + audio_start_idx = ( + start_idx if len(audio_llm_pos_ids_list) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend((pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + llm_pos_ids_list[-1].max() + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + + return llm_positions, mrope_position_delta + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += tuple(video_embeddings) + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += tuple(audio_embeddings) + return multimodal_embeddings + + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. 
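+ # Sketch of what the base-class call below is expected to do (simplified;
+ # not the actual vLLM implementation): look up token embeddings for
+ # input_ids, then overwrite the rows flagged by the `is_multimodal` mask
+ # with the concatenated multimodal embeddings, roughly:
+ #     inputs_embeds = embed_tokens(input_ids)
+ #     inputs_embeds[is_multimodal] = torch.cat(list(multimodal_embeddings))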
+ def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().embed_input_ids(input_ids) + + return super().embed_input_ids( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = ["talker.", "token2wav."] + if self.audio_tower is None: + skip_prefixes.extend(["audio_tower."]) + if self.visual is None: + skip_prefixes.extend(["visual."]) + + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="merger.", + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py new file mode 100644 index 0000000000000000000000000000000000000000..84ff8b55f0222ae2288b35d37daf49a2089cf77d --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py @@ -0,0 +1,1881 @@ +############################ +# Start Token2Wav # +############################ + +import math +from collections.abc import Iterable + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniBigVGANConfig, + Qwen2_5OmniDiTConfig, + Qwen2_5OmniToken2WavConfig, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniPreTrainedModel + +# Bring in HF base classes, configs and utilities used below +from transformers.utils.logging import get_logger as _hf_get_logger +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import AutoWeightsLoader as _Vllm_AutoWeightsLoader +from vllm.model_executor.models.utils import WeightsMapper as _Vllm_WeightsMapper +from vllm.model_executor.models.utils import init_vllm_registered_model as _vllm_init_vllm_registered_model +from 
vllm.model_executor.models.utils import maybe_prefix as _vllm_maybe_prefix +from vllm.sequence import IntermediateTensors +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import cap_and_align_mel_length, resolve_max_mel_frames +from vllm_omni.platforms import current_omni_platform + + +# Provide a no-op auto_docstring decorator to satisfy annotations if missing +def auto_docstring(func=None, **_kwargs): + if func is None: + + def wrapper(f): + return f + + return wrapper + return func + + +# HF logger alias +logger = _hf_get_logger(__name__) + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = 
self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 # noqa: E501 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, + max_len=seq_length, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). + """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + 
speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: bool | None = False, + code_embed_uncond: bool | None = None, + apply_cfg: bool | None = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +# Modified from Llama with a different rotate function, will fixed in next release +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. 
+ sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to + unsqueeze cos[position_ids] and sin[position_ids] so that they can be + properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape + [batch_size, seq_len, head_dim]. Then, if q and k have the shape + [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 + makes cos[position_ids] and sin[position_ids] broadcastable to the + shapes of q and k. Similarly, if q and k have the shape + [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated + using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, prefix: str = ""): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self.is_causal = False + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.dim, + head_size=config.head_dim, + total_num_heads=self.heads, + bias=True, + prefix=f"{prefix}.qkv_proj", + disable_tp=True, + return_bias=False, + ) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.dim), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + qkv = self.qkv_proj(hidden_states) + query, key, value = qkv.split([self.inner_dim, self.inner_dim, self.inner_dim], dim=-1) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, + # will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. 
inference got a batch with different target durations, + # mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude + of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper + by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + 
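+ # Note: alpha and beta are stored in log scale and exponentiated in
+ # forward(), so the zero initialization above corresponds to
+ # alpha = beta = 1 at start, i.e. a plain Snake activation x + sin^2(x).
+ # The value computed below is
+ #     x + (1 / (exp(beta) + eps)) * sin(x * exp(alpha)) ** 2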
def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + # TODO: When torch.kaiser_window supports NPU, remove the device="cpu" argument + if current_omni_platform.is_npu(): + kaiser_window = torch.kaiser_window( + kernel_size, beta=beta, periodic=False, dtype=torch.float32, device="cpu" + ).to("npu") + elif current_omni_platform.is_xpu(): + kaiser_window = torch.kaiser_window( + kernel_size, beta=beta, periodic=False, dtype=torch.float32, device="cpu" + ).to("xpu") + else: + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +def replication_pad_1d(hidden_states: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor: + """ + Manual replicate padding to avoid replication_pad1d kernel limits on NPU. + TODO: remove when F.pad supports replicate mode on NPU. + """ + # NOTE: a immature implementation for running in NPU. Need to discuss. 
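+ # Intended to be equivalent to
+ #     F.pad(hidden_states, (pad_left, pad_right), mode="replicate")
+ # on the last dimension: the first element is repeated pad_left times on
+ # the left and the last element pad_right times on the right.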
+ if pad_left == 0 and pad_right == 0: + return hidden_states + + segments = [] + if pad_left > 0: + left = hidden_states[..., :1].expand(*hidden_states.shape[:-1], pad_left) + segments.append(left) + + segments.append(hidden_states) + + if pad_right > 0: + right = hidden_states[..., -1:].expand(*hidden_states.shape[:-1], pad_right) + segments.append(right) + + return torch.cat(segments, dim=-1) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + if current_omni_platform.is_npu(): + # TODO: When F.pad supports replicate mode on NPU, remove this branch + input_dtype = hidden_states.dtype + # F.pad in NPU doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor. + hidden_states = replication_pad_1d(hidden_states.to(self.filter.dtype), self.pad, self.pad) + filter_convert_dtype = self.filter.to(hidden_states.dtype) + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, + filter_convert_dtype.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(input_dtype) + else: + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to(self.filter.dtype) + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, + self.filter.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(hidden_states_dtype) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + if current_omni_platform.is_npu(): + input_dtype = hidden_states.dtype + # F.pad in NPU doesn't support BF16 when mode is replicate. + # To ensure the accuracy, manually pad the input tensor. 
+ hidden_states = replication_pad_1d(hidden_states.to(self.filter.dtype), self.pad_left, self.pad_right) + filter_on_device = self.filter.to(device=hidden_states.device, dtype=hidden_states.dtype) + out = F.conv1d( + hidden_states, + filter_on_device.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(input_dtype) + else: + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate").to( + self.filter.dtype + ) + out = F.conv1d( + hidden_states, + self.filter.expand(channels, -1, -1), + stride=self.stride, + groups=channels, + ).to(hidden_states_dtype) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise TypeError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram + as input and predict waveform. 
+ """ +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock( + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + dilation, + ) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), + 1, + 7, + 1, + padding=3, + bias=False, + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp( + (2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, + -max_value, + max_value, + ) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor( + min_db_level / 20.0 * np.log(10), + device=amplitude.device, + dtype=amplitude.dtype, + ) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step( + self, + function, + time_start, + time_step, + time_end, + value_start, + function_value_start=None, + ): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function( + time_start + time_step * self._one_third, + value_start + time_step * k1 * self._one_third, + ) + k3 = function( + time_start + time_step * self._two_thirds, + value_start + 
time_step * (k2 - k1 * self._one_third), + ) + k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return ( + self._rk4_step( + function, + time_start, + time_step, + time_end, + value_start, + function_value_start=function_value_start, + ), + function_value_start, + ) + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, + time_end, + current_value, + next_value, + time_points[current_index], + ) + current_index += 1 + + current_value = next_value + + return solution + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavDiT model. Which take speech tokens as + input and predict mel spectrogram. 
+ """ +) +class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + max_mel_frames: int | None = None, + ): + max_mel_frames = resolve_max_mel_frames(max_mel_frames, default=30000) + target_code_len, target_duration = cap_and_align_mel_length( + code_len=int(quantized_code.shape[1]), + repeats=int(self.repeats), + max_mel_frames=max_mel_frames, + ) + if int(quantized_code.shape[1]) != target_code_len: + quantized_code = quantized_code[:, :target_code_len] + + initial_state = torch.randn( + [1, 
target_duration, self.mel_dim], + dtype=reference_mel_spectrogram.dtype, + device=quantized_code.device, + ) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, target_duration, 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, + 1, + num_steps, + device=quantized_code.device, + dtype=conditioning_vector.dtype, + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + def fast_block_sample( + self, + conditioning_vector: torch.Tensor, + reference_mel_spectrogram: torch.Tensor, + quantized_code: torch.Tensor, + y0: torch.Tensor, + num_steps: int = 10, + guidance_scale: float = 0.5, + sway_coefficient: float | None = -1.0, + ) -> torch.Tensor: + """ + Block-wise ODE sampling starting from provided initial state y0. 
+
+        Args:
+            conditioning_vector: (B, enc_emb_dim)
+            reference_mel_spectrogram: (B, T_ref, mel_dim)
+            quantized_code: (B, T_code)
+            y0: (B, T_target, mel_dim) initial state for ODE
+        Returns:
+            mel: (B, mel_dim, T_target)
+        """
+        initial_state = y0.to(quantized_code.device)
+        batch_size = reference_mel_spectrogram.shape[0]
+        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, initial_state.shape[1], 1)
+
+        if batch_size != 1:
+            raise ValueError("Only batch size = 1 is currently supported")
+
+        def ode_function(time_step, hidden_states):
+            if guidance_scale < 1e-5:
+                # No-CFG path: single conditioned pass, mirroring `sample()`.
+                prediction = self(
+                    hidden_states=hidden_states,
+                    speaker_embedding=conditioning_vector,
+                    condition_vector=reference_mel_spectrogram,
+                    quantized_code=quantized_code,
+                    time_step=time_step,
+                    drop_audio_conditioning=False,
+                    drop_code=False,
+                    apply_cfg=False,
+                )
+                return prediction
+
+            model_output = self(
+                hidden_states=hidden_states,
+                quantized_code=quantized_code,
+                speaker_embedding=conditioning_vector,
+                condition_vector=reference_mel_spectrogram,
+                time_step=time_step,
+                apply_cfg=True,
+            )
+            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)
+            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale
+
+        initial_time = 0
+        time_embedding = torch.linspace(
+            initial_time,
+            1,
+            num_steps,
+            device=quantized_code.device,
+            dtype=conditioning_vector.dtype,
+        )
+
+        if sway_coefficient is not None:
+            time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)
+
+        ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state)
+        solution_trajectory = ode_solver.integrate(time_embedding)
+
+        generated_waveform = solution_trajectory[-1]
+        generated_mel_spectrogram = generated_waveform.permute(0, 2, 1)
+        return generated_mel_spectrogram
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            # self-attn
+            (".qkv_proj", ".to_q", "q"),
+            (".qkv_proj", ".to_k", "k"),
+            (".qkv_proj", ".to_v", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        loaded_params = set[str]()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+@auto_docstring(
+    custom_intro="""
+    The full Qwen2.5Omni Token2Wav model. Consists of a DiT model that takes
+    speech tokens as input and predicts a mel spectrogram, and a BigVGAN
+    vocoder that takes the mel spectrogram as input and predicts the waveform.
+ """ +) +class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniToken2WavConfig + base_model_prefix = "model" + _no_split_modules = [ + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + ] + + def __init__(self, config: Qwen2_5OmniToken2WavConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen2_5OmniToken2WavModel must inference with fp32, but " + "flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen2_5OmniToken2WavModel will " + "fallback to sdpa." + ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( + config.dit_config, attn_implementation=attn_impl + ) + self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + # Streaming-related parameters aligned with Qwen2Code2wav + self.factor = self.code2wav_dit_model.repeats # 50Hz=2, 200Hz=4 + # default bs_mel depends on factor + self.bs_mel = 24 if self.factor == 2 else 32 + self.bs_codec = self.bs_mel // self.factor + self.past_cache_size = self.bs_mel * self.factor + self.future_cache_size = self.bs_mel * 1 + self.batched_chunk = 3 + self.chunk_size = self.bs_mel * self.batched_chunk + self.future_size = 20 if self.factor == 2 else 13 + + # codec embedding size for masking EOS out-of-range + try: + self.codec_embed_size = self.code2wav_dit_model.text_embed.codec_embed.weight.size(0) + except Exception: + self.codec_embed_size = -1 + + # vocoder hop length inferred from upsample rates + try: + ups = self.code2wav_bigvgan_model.config.upsample_rates + hop = 1 + for r in ups: + hop *= int(r) + self.vocoder_hop = int(hop) + except Exception: + # fallback to commonly used value + self.vocoder_hop = 240 + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + max_mel_frames: int | None = None, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.code2wav_dit_model.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + max_mel_frames=max_mel_frames, + ).to(self.code2wav_bigvgan_model.dtype) + + waveform = self.code2wav_bigvgan_model(mel_spectrogram).to(self.dtype) + + return waveform + + # ============== Chunked processing helpers (compat with qwen2_code2wav_dit) ============== # noqa: E501 + @torch.inference_mode() + def process_chunk_dit_batch( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + code: torch.Tensor, + y0: torch.Tensor, + steps: int, + ) -> torch.Tensor: + """ + Block-wise DiT: generate mel from initial state y0 for the given code slice. 
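+
+        Expected shapes (mirroring ``fast_block_sample``, which this wraps):
+            conditioning: (1, enc_emb_dim)
+            reference_mel: (1, T_ref, mel_dim)
+            code: (1, T_code) codec ids for the current window
+            y0: (1, T_window, mel_dim) initial state for the window
+        Returns the generated mel of shape (1, mel_dim, T_window), cast to the
+        vocoder dtype.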
+ """ + # prevent codec out-of-range (eos) + if self.codec_embed_size > 0: + code = code.clone() + code[code >= self.codec_embed_size] = 0 + mel = self.code2wav_dit_model.fast_block_sample( + conditioning_vector=conditioning, + reference_mel_spectrogram=reference_mel, + quantized_code=code, + y0=y0, + num_steps=steps, + ) + return mel.to(self.code2wav_bigvgan_model.dtype) + + @torch.inference_mode() + def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> torch.Tensor: + """Vocoder batch: mel -> waveform.""" + return self.code2wav_bigvgan_model(mel_batch) + + @torch.inference_mode() + def process_little_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor, + finished: bool = False, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + """Streaming per small chunk: returns (mel_or_None, audio_slice).""" + start_index = max(i * self.chunk_size - self.past_cache_size, 0) + end_index = min( + (i + 1) * self.chunk_size + self.future_cache_size, + codec_all.shape[1] * self.factor, + ) + + y0 = y_all[:, start_index:end_index].reshape(1, -1, self.code2wav_dit_model.mel_dim).contiguous() + codec = codec_all[:, start_index // self.factor : end_index // self.factor].reshape(1, -1).contiguous() + + # generate mel for current window (B, mel_dim, T) + generated = self.process_chunk_dit_batch( + conditioning=conditioning, + reference_mel=reference_mel, + code=codec, + y0=y0, + steps=steps, + ) + + # splice and vocode with 50Hz-style rules + return self._process_chunk_for_50hz( + i=i, + start_index=start_index, + end_index=end_index, + finished=finished, + prev_generated=prev_generated, + generated=generated, + ) + + @torch.inference_mode() + def process_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor | list[torch.Tensor], + finished: bool = False, + ) -> tuple[torch.Tensor | list[torch.Tensor], torch.Tensor]: + """High-level chunk API aligning to qwen2_code2wav_dit signature.""" + if not isinstance(prev_generated, torch.Tensor): + prev_generated = prev_generated[0] if len(prev_generated) > 0 else None + _mel, audio = self.process_little_chunk( + conditioning=conditioning, + reference_mel=reference_mel, + codec_all=codec_all, + y_all=y_all, + i=i, + steps=steps, + prev_generated=prev_generated, + finished=finished, + ) + return _mel if _mel is not None else prev_generated, audio + + @torch.inference_mode() + def _process_chunk_for_50hz( + self, + i: int, + start_index: int, + end_index: int, + finished: bool, + prev_generated: torch.Tensor | None, + generated: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Align mel and audio boundaries for 50Hz-like streaming. 
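+
+        Illustrative numbers (defaults from ``__init__`` when ``repeats == 2``):
+        chunk_size = 72, past_cache_size = 48, future_cache_size = 24,
+        future_size = 20. For an intermediate chunk whose window does not start
+        at frame 0, the DiT output is trimmed by 48 frames at the front and 24
+        frames at the back, the last ``future_size * 2 = 40`` frames of
+        ``prev_generated`` are prepended, and the vocoded audio is cropped by
+        ``future_size * vocoder_hop`` samples on both ends.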
+ + Shapes: + - generated: (B, mel_dim, T_window) + - prev_generated: (B, mel_dim, T_prev) + Returns: + - mel_chunk: (B, mel_dim, T_chunk) + - audio_slice: (T_audio_chunk,) + """ + # Normalize dtype + generated = generated.to(torch.float32) + if i == 0: + mel = generated[:, :, : self.chunk_size] + elif finished: + mel_trim = generated[:, :, self.past_cache_size :] + mel = torch.cat([prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2) + else: + if start_index == 0: + mel_trim = generated[:, :, i * self.chunk_size : -self.future_cache_size] + else: + mel_trim = generated[:, :, self.past_cache_size : -self.future_cache_size] + mel = torch.cat([prev_generated[:, :, -self.future_size * 2 :], mel_trim], dim=2) + + audio = self.code2wav_bigvgan_model(mel) + if i == 0: + audio_output = audio[: -self.future_size * self.vocoder_hop] + elif finished: + audio_output = audio[self.future_size * self.vocoder_hop :] + else: + audio_output = audio[self.future_size * self.vocoder_hop : -self.future_size * self.vocoder_hop] + return mel, audio_output + + +# ================= vLLM-style wrapper for Token2Wav ================= + + +class Qwen2_5OmniToken2WavForConditionalGenerationVLLM(nn.Module, SupportsPP): + logger = init_logger(__name__) + + # Map HF weights -> vLLM module names + hf_to_vllm_mapper = _Vllm_WeightsMapper( + orig_to_new_prefix={ + # HF root is 'model.' + "model.": "token2wav_model.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + # Expect hf_config to be Token2Wav config + self.config = vllm_config.model_config.hf_config + + # Initialize underlying HF Token2Wav model via registry + self.token2wav = _vllm_init_vllm_registered_model( + vllm_config=vllm_config, + prefix=_vllm_maybe_prefix(prefix, "token2wav_model"), + hf_config=self.config, + architectures=["Qwen2_5OmniToken2WavDiTModel"], + ) + + # Provide placeholder to align with vLLM runner expectations + def _empty_intermediate_tensors(): + return None + + self.make_empty_intermediate_tensors = _empty_intermediate_tensors + + def get_language_model(self) -> torch.nn.Module: + return self.token2wav + + @property + def sampler(self): + # Token2Wav does not use sampler; return vLLM default for API parity + return Sampler() + + def forward( + self, + code: torch.Tensor, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + num_steps: int = 10, + guidance_scale: float = 0.5, + sway_coefficient: float = -1.0, + intermediate_tensors: IntermediateTensors | None = None, + **kwargs, + ) -> torch.Tensor: + # Delegate to HF token2wav model + return self.token2wav( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + # Token2Wav outputs waveform; logits are not applicable + return hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + return None + + def load_weights_without_buffers(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = _Vllm_AutoWeightsLoader(self) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + # Log load summary + try: + total_bytes = 0 + for _, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = 
next(self.parameters()).device + self.logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, + True, + total_bytes / (1024**2), + str(device), + ) + except Exception: + pass + return loaded + + def find_all_registers(self): + """ + Find all registered buffers in a PyTorch model. + + Args: + Returns: + dict: Dictionary with buffer names as keys and their properties as values + """ + registers = {} + + # Get all named buffers + for name, buf in self.named_buffers(): + if name in self.state_dict(): + registers[name] = {"name": name, "buffer": buf} + return registers + + # remove buffers from the weights and reload them after loading weights + def remove_buffers_from_weights(self, weights: Iterable[tuple[str, torch.Tensor]], buffers: dict): + weights_to_load = [] + for key, value in weights: + if key in buffers: + buffers[key]["buffer"] = value + continue + weights_to_load.append((key, value)) + return weights_to_load + + def reload_buffers_to_model(self, buffers: dict): + """ + reload stored buffers from weights to model + """ + loaded_buffers = set() + for name, buf_val in self.named_buffers(): + if name in buffers: + buf_val.copy_(buffers[name]["buffer"]) + loaded_buffers.add(name) + return loaded_buffers + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]], spk_dict_path: str) -> set[str]: + buffers = self.find_all_registers() + weights_to_load = self.remove_buffers_from_weights(weights, buffers) + loaded = self.load_weights_without_buffers(weights_to_load) + loaded_buffers = self.reload_buffers_to_model(buffers) + # merge loaded and loaded_buffers + loaded.update(loaded_buffers) + self.spk_dict = torch.load(spk_dict_path) + return loaded + + # ============== Optional chunked helpers for API parity ============== + @torch.inference_mode() + def process_chunk_dit_batch( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + code: torch.Tensor, + y0: torch.Tensor, + steps: int, + ) -> torch.Tensor: + return self.token2wav( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=steps, + ) + + @torch.inference_mode() + def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> torch.Tensor | None: + # BigVGAN is not part of this wrapper; return None for parity. 
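+        # Callers that need actual waveform output can instead use
+        # Qwen2_5OmniToken2WavModel.process_chunk_bigvgan_batch, which invokes
+        # the BigVGAN vocoder on the mel batch.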
+ return None + + @torch.inference_mode() + def process_little_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor, + finished: bool = False, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + mel = self.token2wav( + code=codec_all, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=steps, + ) + return None, mel + + @torch.inference_mode() + def process_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor | list[torch.Tensor], + finished: bool = False, + ) -> tuple[torch.Tensor | list[torch.Tensor], torch.Tensor]: + _mel, out = self.process_little_chunk( + conditioning=conditioning, + reference_mel=reference_mel, + codec_all=codec_all, + y_all=y_all, + i=i, + steps=steps, + prev_generated=(prev_generated if isinstance(prev_generated, torch.Tensor) else None), + finished=finished, + ) + return _mel if _mel is not None else prev_generated, out diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py new file mode 100644 index 0000000000000000000000000000000000000000..e04010196a8165dbd34cd7a63ffeed942188999f --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py @@ -0,0 +1,456 @@ +from collections.abc import Iterable + +import torch +from torch import nn +from transformers import Qwen2Config +from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backend import AttentionType +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +logger = init_logger(__name__) + + +class Qwen2MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + 
quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + head_dim: int | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_pos_emb = get_rope( + head_size=self.head_dim, + max_position=max_position, + is_neox_style=True, + rope_parameters={ + "base": self.rope_theta, + **rope_scaling, + }, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_pos_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen2Config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + head_dim=getattr(config, "head_dim", None), + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) +class Qwen2Model(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer, + ): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if cache_config.sliding_window is not None and hasattr(config, "max_window_layers"): + raise ValueError( + "Sliding window for some but all layers is not " + "supported. This model uses sliding window " + f"but `max_window_layers` = {config.max_window_layers} is less than " + f"`num_hidden_layers` = {config.num_hidden_layers}. Please open an issue " + "to discuss this feature." 
+ ) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + getattr(config, "embedding_size", config.hidden_size), + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to Qwen2DecoderLayer + decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name)): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm_omni/model_executor/models/qwen3_omni/__init__.py b/vllm_omni/model_executor/models/qwen3_omni/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a469704f6140664987a575adf5cf1fa3bfe832dd --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/__init__.py @@ -0,0 +1,3 @@ +from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration + +__all__ = ["Qwen3OmniMoeForConditionalGeneration"] diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9332363136d6f0aa7dc07c4c64d9efaf1caf602e --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from 
vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.qwen3_moe import ( + Qwen3MoeDecoderLayer, + Qwen3MoeMLP, + Qwen3MoeModel, # as _BaseQwen3MoeModel, +) +from vllm.model_executor.models.qwen3_moe import ( + Qwen3MoeForCausalLM as _BaseQwen3MoeForCausalLM, +) +from vllm.model_executor.models.utils import ( + PPMissingLayer, + maybe_prefix, +) + +logger = init_logger(__name__) + + +# Individual expert MoE block using Qwen3MoeMLP instead of FusedMoE +class Qwen3OmniMoeSparseMoeBlock(nn.Module): + """Sparse MoE block using individual Qwen3MoeMLP experts instead of FusedMoE.""" + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + config = vllm_config.model_config.hf_text_config + quant_config = vllm_config.quant_config + + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + + # Create individual expert MLPs + self.experts = nn.ModuleList( + [ + Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.experts.{i}", + ) + for i in range(self.num_experts) + ] + ) + + # Router for expert selection + from vllm.model_executor.layers.linear import ReplicatedLinear + + self.gate = ReplicatedLinear( + config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate" + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass using individual experts.""" + # Handle 3D inputs (batch, seq_len, hidden_size) by reshaping to 2D + orig_shape = hidden_states.shape + if hidden_states.dim() == 3: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + elif hidden_states.dim() == 2: + num_tokens, hidden_dim = hidden_states.shape + elif hidden_states.dim() == 1: + hidden_states = hidden_states.unsqueeze(0) + num_tokens, hidden_dim = hidden_states.shape + else: + raise ValueError( + f"Qwen3OmniMoeSparseMoeBlock only supports 1D, 2D, or 3D inputs, got {hidden_states.dim()}D" + ) + + is_input_1d = len(orig_shape) == 1 + hidden_states = hidden_states.view(-1, hidden_dim) + + # Get router logits and select experts (matching transformers) + router_logits, _ = self.gate(hidden_states) + selected_experts, routing_weights = self._route_tokens(router_logits) + + # Forward through individual experts + final_hidden_states = self._forward_experts(hidden_states, selected_experts, routing_weights) + + # Reshape back to original shape + if is_input_1d: + return final_hidden_states.squeeze(0) + elif len(orig_shape) == 3: + # Reshape back to 3D (batch, seq_len, hidden_dim) + return final_hidden_states.view(orig_shape) + else: + return final_hidden_states + + def _route_tokens(self, router_logits: torch.Tensor): + """Route tokens to experts using top-k selection (matching transformers).""" + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(router_logits.dtype) + return selected_experts, 
routing_weights + + def _forward_experts( + self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor + ): + """Forward through individual experts (matching transformers implementation).""" + final_hidden_states = torch.zeros_like(hidden_states) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) + current_hidden_states = self.experts[expert_idx](current_state) * routing_weights[top_x, idx, None] + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + return final_hidden_states + + +class Qwen3MoeForCausalLM(_BaseQwen3MoeForCausalLM): + """Thin wrapper to swap in the patched `Qwen3MoeModel`.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Don't call super().__init__() to avoid duplicate layer registration. + nn.Module.__init__(self) + config = vllm_config.model_config.hf_text_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head") + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + # Set MoE hyperparameters for individual experts + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3MoeDecoderLayer) + if isinstance(layer.mlp, FusedMoE): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp) + + if example_layer is None: + raise RuntimeError("No Qwen3OmniMoe layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..37753ee47842aef105096f2727946590d5f6caf9 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -0,0 +1,1145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. 
+"""Inference-only Qwen3-Omni-Moe unified model (thinker + talker + code2wav).""" + +from collections.abc import Iterable +from functools import cached_property + +import torch +import torch.nn as nn +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeCode2WavConfig, + Qwen3OmniMoeConfig, + Qwen3OmniMoeTalkerConfig, + Qwen3OmniMoeThinkerConfig, +) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen3_omni_moe_thinker import ( + Qwen3OmniMoeConditionalGenerationMixin, + Qwen3OmniMoeThinkerDummyInputsBuilder, + Qwen3OmniMoeThinkerMultiModalProcessor, + Qwen3OmniMoeThinkerProcessingInfo, +) +from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFeatureSpec +from vllm.sequence import IntermediateTensors +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, safe_tensor_reshape +from vllm_omni.platforms import current_omni_platform + +# Special token IDs for Qwen3 Omni MoE +# Reference: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json + +# Audio tokens (thinker vocabulary, for marking audio boundaries) +AUDIO_START_TOKEN_ID = 151669 # <|audio_start|> (audio_bos_token) +AUDIO_END_TOKEN_ID = 151670 # <|audio_end|> (audio_eos_token) +AUDIO_PAD_TOKEN_ID = 151675 # <|audio_pad|> + +# TTS text tokens (thinker vocabulary, for text-to-speech control) +TTS_PAD_TOKEN_ID = 151671 # <tts_pad> +TTS_BOS_TOKEN_ID = 151672 # <tts_text_bos> +TTS_EOS_TOKEN_ID = 151673 # <tts_text_eod> (end of dialogue) +TTS_BOS_SINGLE_TOKEN_ID = 151674 # <tts_text_bos_single> + +# Talker codec tokens (talker vocabulary, used for RVQ code generation) +TALKER_CODEC_PAD_TOKEN_ID = 4196 # Padding token +TALKER_CODEC_BOS_TOKEN_ID = 4197 # Beginning of speech +TALKER_CODEC_EOS_TOKEN_ID = 4198 # End of speech +TALKER_CODEC_NOTHINK_ID = 4203 # No-think mode +TALKER_CODEC_THINK_BOS_ID = 4204 # Think mode start +TALKER_CODEC_THINK_EOS_ID = 4205 # Think mode end + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, Qwen3OmniMoeConditionalGenerationMixin, CustomProcessMixin, SupportsMRoPE +): + """ + Unified Qwen3 Omni MoE model combining thinker, talker, and code2wav. 
+ + Architecture: + - Thinker: Multimodal understanding (text + audio + video) → text generation + - Talker: Text embeddings → RVQ codec codes + - Code2Wav: RVQ codes → audio waveform + + Usage: + Set `model_stage` in vllm_config to one of: "thinker", "talker", "code2wav" + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + config: Qwen3OmniMoeConfig = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + + # Keep vllm_config for later submodule init + self.vllm_config = vllm_config + self.config = config + + # Initialize thinker components + thinker_config: Qwen3OmniMoeThinkerConfig = config.thinker_config + self.thinker_config = thinker_config + self.multimodal_config = multimodal_config + + # Initialize talker components + talker_config: Qwen3OmniMoeTalkerConfig = config.talker_config + self.talker_config = talker_config + + # Initialize code2wav components + code2wav_config: Qwen3OmniMoeCode2WavConfig = config.code2wav_config + self.code2wav_config = code2wav_config + + # Determine model stage + self.model_stage = vllm_config.model_config.model_stage + + if self.model_stage == "thinker": + # Initialize thinker model (multimodal processing + text generation) + # Create a new vllm_config with thinker_config as the hf_config + thinker_vllm_config = vllm_config.with_hf_config( + thinker_config, architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"] + ) + self.thinker = init_vllm_registered_model( + vllm_config=thinker_vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + hf_config=thinker_config, + architectures=["Qwen3OmniMoeThinkerForConditionalGeneration"], + ) + self.model = self.thinker + self.talker = None + self.code2wav = None + self.tts_tokens = torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=self._module_device(self.thinker), + dtype=torch.long, + ) + elif self.model_stage == "talker": + self.has_preprocess = True + self.has_postprocess = True + self.set_custom_preprocess(self.talker_preprocess) + self.set_custom_postprocess(self.talker_postprocess) + self.thinker = None + # Initialize talker model (text embeddings → codec codes) + # Create a new vllm_config with talker_config as the hf_config + # This ensures the talker uses its own text_config (smaller vocab_size) + talker_vllm_config = vllm_config.with_hf_config( + talker_config, architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"] + ) + self.talker = init_vllm_registered_model( + vllm_config=talker_vllm_config, + prefix=maybe_prefix(prefix, "talker"), + hf_config=talker_config, + architectures=["Qwen3OmniMoeTalkerForConditionalGeneration"], + ) + self.talker.init_multi_modal(thinker_config) + self.model = self.talker + self.code2wav = None + + # for CI: Initialize special tokens embeddings early to avoid AttributeError when loading dummy weights + self._init_special_tokens_embeddings() + self.requires_raw_input_tokens = True + + elif self.model_stage == "code2wav": + self.thinker = None + self.talker = None + # Initialize code2wav (codec codes → audio waveform) + # Create a new vllm_config with code2wav_config as the hf_config + code2wav_vllm_config = vllm_config.with_hf_config(code2wav_config, architectures=["Qwen3OmniMoeCode2Wav"]) + self.code2wav = init_vllm_registered_model( + vllm_config=code2wav_vllm_config, + prefix=maybe_prefix(prefix, 
"code2wav"), + hf_config=code2wav_config, + architectures=["Qwen3OmniMoeCode2Wav"], + ) + self.model = self.code2wav + self.requires_raw_input_tokens = True + else: + raise ValueError( + f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'talker', 'code2wav'" + ) + + # Set up intermediate tensors + self.make_empty_intermediate_tensors = ( + self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None + ) + + # ==================== Device utilities ==================== + + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + """Get the device of a module.""" + try: + return next(module.parameters()).device + except StopIteration: + # No parameters; fall back to buffers or cpu + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + + @cached_property + def sampler(self): + """Get sampler from active model.""" + if hasattr(self.model, "sampler"): + return self.model.sampler + return Sampler() + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + is_multimodal=None, + ) -> torch.Tensor: + if self.model_stage == "code2wav": + return torch.zeros_like(input_ids).reshape(-1, 1).repeat(1, self.vllm_config.model_config.get_hidden_size()) + return self.model.embed_input_ids( + input_ids=input_ids, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal + ) + + def embed_multimodal(self, **kwargs): + """Delegate to active model for multimodal processing.""" + return self.model.embed_multimodal(**kwargs) + + # ==================== Forward Pass ==================== + def _get_talker_suppressed_tokens(self): + return [ + i + for i in range( + self.config.talker_config.text_config.vocab_size - 1024, + self.config.talker_config.text_config.vocab_size, + ) + if i != self.config.talker_config.codec_eos_token_id + ] + + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec] | None = None, + **kwargs: object, + ) -> tuple[torch.Tensor, int]: + if self.model_stage == "thinker": + if mm_features is None: + msg = "Qwen3 Omni thinker get_mrope_input_positions requires mm_features" + raise ValueError(msg) + return self.thinker.get_mrope_input_positions(input_tokens, mm_features) + return MRotaryEmbedding.get_input_positions_tensor(input_tokens, **kwargs) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + generate_audio: bool = True, + voice_type: str = "ethan", + codec: torch.Tensor | None = None, + sampling_metadata: SamplingMetadata | None = None, + logits_index: int | None = None, + additional_information: dict[str, object] | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors | OmniOutput: + """ + Unified forward pass for all model stages. 
+ + Workflow: + 1) Thinker: multimodal understanding → text hidden states + 2) Talker -> Code Predictor: text embeddings → codec codes (layer 0 + code_predictor:residual layers) + 3) Code2wav: 8-layer RVQ codes → audio waveform + + Returns: + OmniOutput with text_hidden_states and optional audio + """ + + # ========== Stage 1: Thinker ========== + if self.model_stage == "thinker": + thinker_dev = self._module_device(self.thinker) + if current_omni_platform.is_npu(): + # Normalize to batched inputs if needed + _added_batch_dim = False + if input_ids is not None and input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + _added_batch_dim = True + if positions is not None and positions.ndim == 1: + positions = positions.unsqueeze(0) + _added_batch_dim = True + if inputs_embeds is not None and inputs_embeds.ndim == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + _added_batch_dim = True + + # Handle None input_ids + if input_ids is None: + input_ids = torch.zeros( + inputs_embeds.shape[1], + dtype=torch.long, + device=thinker_dev, + ).unsqueeze(0) + _added_batch_dim = True + + # Move to thinker device + if input_ids is not None and input_ids.device != thinker_dev: + input_ids = input_ids.to(thinker_dev) + if positions is not None and positions.device != thinker_dev: + positions = positions.to(thinker_dev) + if inputs_embeds is not None and inputs_embeds.device != thinker_dev: + inputs_embeds = inputs_embeds.to(thinker_dev) + + # Run thinker forward + # If talker expects a specific intermediate layer, capture it here + accept_layer = getattr(self.talker_config, "accept_hidden_layer", None) + capture_kwargs = {} + if accept_layer is not None: + capture_kwargs = { + "capture_layer_indices": [0, int(accept_layer)], + "return_hidden_states": True, + } + if current_omni_platform.is_npu(): + # TODO: remove this hack when NPU supports batched inputs properly + thinker_input_ids = input_ids[0] if input_ids is not None and _added_batch_dim else input_ids + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and _added_batch_dim else inputs_embeds + ) + else: + thinker_input_ids = input_ids + thinker_inputs_embeds = inputs_embeds + + # Run thinker + text_hidden_states, captured_layer_dict = self.thinker( + input_ids=thinker_input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=thinker_inputs_embeds, + **capture_kwargs, + **kwargs, + ) + return text_hidden_states, captured_layer_dict + + # ========== Stage 2.1: Talker ========== + elif self.model_stage == "talker": + if input_ids is None: + # special case for profile run + input_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device) + + # Ensure we have base embeddings when only ids are provided + if inputs_embeds is None and input_ids is not None: + inputs_embeds = self.talker.embed_input_ids(input_ids) + + # TODO(Peiqi): temporal hack here to support voice_type. + if not hasattr(self, "voice_type"): + self.voice_type = voice_type + + # Run talker forward + with torch.inference_mode(): + talker_hidden = self.talker.forward( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return talker_hidden + + # ========== Stage 3: Code2Wav ========== + elif self.model_stage == "code2wav": + # Extract codec codes from input + codes = [] + if input_ids.shape[0] % 16 == 0: + codes.append(input_ids.reshape(1, 16, -1)) + else: + logger.warning( + ( + "Input_ids length: %s is not divisible by 16, padding " + "with zeros. 
This should only happen in warm up." + ), + input_ids.shape[0], + ) + input_ids_flatten = input_ids.reshape(-1) + input_ids_flatten = torch.cat( + [ + input_ids_flatten, + torch.zeros(16 - input_ids.shape[0] % 16, dtype=torch.long, device=input_ids.device), + ] + ) + codes.append(input_ids_flatten.reshape(1, 16, -1)) + + # Generate audio from codec codes + audio_tensors = [] + for code in codes: + audio_tensor = self.generate_audio(code, voice_type) + audio_tensors.append(audio_tensor) + if len(audio_tensors) > 1: + logger.warning( + "Batched input for code2wav is not supported yet, only the first audio tensor will be returned" + ) + + return audio_tensors + + # Fallback (shouldn't reach here) + return OmniOutput( + text_hidden_states=torch.zeros( + [inputs_embeds.shape[0], self.talker.config.hidden_size], + dtype=torch.bfloat16, + ).to(self._module_device(self.model)), + multimodal_outputs=None, + ) + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -> OmniOutput: + """ + Make an OmniOutput object from model outputs. + Args: + model_outputs: Model outputs + """ + if isinstance(model_outputs, OmniOutput): + return model_outputs + + if self.model_stage == "thinker": + text_hidden_states, captured_layer_dict = model_outputs + # Compute thinker-side TTS token embeddings for BOS/EOS/PAD and expose via multimodal outputs. + # These will later be projected into talker text space by the talker stage. + multimodal_outputs = captured_layer_dict if captured_layer_dict is not None else {} + try: + thinker_tts_embeds = self.thinker.embed_input_ids(self.tts_tokens) # [1,3,thinker_hidden] + if ( + isinstance(thinker_tts_embeds, torch.Tensor) + and thinker_tts_embeds.ndim == 3 + and thinker_tts_embeds.shape[1] == 3 + ): + bos_eos_pad = thinker_tts_embeds.to(text_hidden_states.device).chunk(3, dim=1) # 3 * [1,1,H] + multimodal_outputs["tts_bos_embed"] = [bos_eos_pad[0]] + multimodal_outputs["tts_eos_embed"] = [bos_eos_pad[1]] + multimodal_outputs["tts_pad_embed"] = [bos_eos_pad[2]] + except Exception: + # Best-effort; absence will be handled by talker with fallbacks + pass + + # Return text-only output (with multimodal sidecar) + return OmniOutput( + text_hidden_states=(text_hidden_states.reshape(-1, text_hidden_states.shape[-1])), + multimodal_outputs=multimodal_outputs, + ) + elif self.model_stage == "talker": + talker_hidden = model_outputs + # merge the code_predictor_codes from the info_dict list into a single tensor + multimodal_outputs: dict = None + # Here is the only place to use runtime_additional_information. After MTP in the + # preprocess function, the code_predictor_codes are stored in the info_dict list. + # We need to merge the tensors from different requests into a single tensor. + # In the future, we may allow user to custom an aggregated function. 
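+            # runtime_additional_information is a list with one entry per request;
+            # each entry is the update dict populated during preprocessing (see
+            # talker_preprocess / talker_mtp), which is expected to carry
+            # "code_predictor_codes" for that request.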
+ info_dicts = kwargs.get("runtime_additional_information") + code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts] + multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)} + span_len = multimodal_outputs["code_predictor_codes"].shape[0] + talker_hidden = talker_hidden[:span_len] + return OmniOutput(text_hidden_states=talker_hidden, multimodal_outputs=multimodal_outputs) + elif self.model_stage == "code2wav": + audio_tensors = model_outputs + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": audio_tensors[0].reshape(1, -1)}, + ) + + return model_outputs + + # ==================== Audio Generation ==================== + + def generate_audio(self, code: torch.Tensor, voice_type: str) -> torch.Tensor: + """ + Generate audio waveform from codec codes. + + Args: + code: [8, T] - 8-layer RVQ codec codes + voice_type: Voice type (not used in Qwen3, kept for compatibility) + + Returns: + audio_tensor: [1, waveform_len] - Audio waveform + """ + code2wav_dev = self._module_device(self.code2wav) + + # Convert to tensor if needed + if isinstance(code, torch.Tensor): + talker_codes = code.to(dtype=torch.long, device=code2wav_dev) + else: + talker_codes = torch.as_tensor(code, dtype=torch.long, device=code2wav_dev) + + # Ensure shape is [batch=1, 8, T] + if talker_codes.ndim == 2: + # [8, T] → [1, 8, T] + talker_codes = talker_codes.unsqueeze(0) + elif talker_codes.ndim == 1: + # [T] → assume single layer, expand to 16 layers + talker_codes = talker_codes.unsqueeze(0).unsqueeze(0) + talker_codes = talker_codes.expand(1, 16, -1) + + if self.vllm_config.model_config.async_chunk: + audio_tensor = self.code2wav.chunked_decode_streaming( + talker_codes, + chunk_size=25, + left_context_size=25, + ) + else: + # Use chunked decode for memory efficiency + audio_tensor = self.code2wav.chunked_decode( + talker_codes, + chunk_size=300, + left_context_size=25, + ) + + return audio_tensor + + # ==================== Thinker-Talker Projection ==================== + + def _load_talker_embedding(self) -> torch.nn.Embedding: + """Load talker embedding layer.""" + return self.talker.language_model.model.codec_embedding + + def _init_special_tokens_embeddings(self) -> set[str]: + """ + Initialize special token embeddings for thinker-talker projection. 
+ + Following Transformers implementation: + - TTS tokens (BOS/EOS/PAD) come from thinker's embedding, projected to talker space + - Codec tokens (BOS/EOS/PAD/NOTHINK/THINK_*) come from talker's embedding + - Speaker tokens are also from talker's embedding + + Note on projections: + - text_projection: Used here for text token embeddings (thinker → talker dimension) + - hidden_projection: Used at runtime for multimodal hidden states (audio/image/video) + from thinker's last layer, not needed for special token initialization + """ + self.talker_embedding = self._load_talker_embedding() + + # Get configuration + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, "talker_config"): + talker_hf_config = talker_hf_config.talker_config + + codec_special_tokens = torch.tensor( + [ + [ + talker_hf_config.codec_nothink_id, + talker_hf_config.codec_think_bos_id, + talker_hf_config.codec_think_eos_id, + talker_hf_config.codec_pad_id, + talker_hf_config.codec_bos_id, + talker_hf_config.codec_eos_token_id, + ] + ], + device=self._module_device(self.talker), + dtype=torch.long, + ) + codec_embeds = self.talker_embedding(codec_special_tokens) # [1, 6, talker_hidden] + ( + self.embed_codec_nothink_token, + self.embed_codec_think_bos_token, + self.embed_codec_think_eos_token, + self.embed_codec_pad_token, + self.embed_codec_bos_token, + self.embed_codec_eos_token, + ) = codec_embeds.chunk(6, dim=1) + + # Speaker token IDs (for voice selection) + # In Qwen3, speaker_id mapping is in talker_config.speaker_id + if hasattr(talker_hf_config, "speaker_id") and talker_hf_config.speaker_id: + self.tts_text_spk_token_ids = talker_hf_config.speaker_id + else: + # Default to audio_start_token_id if no speaker mapping + self.tts_text_spk_token_ids = { + "default": talker_hf_config.audio_start_token_id, + "Ethan": talker_hf_config.audio_start_token_id, + "prefix_caching": talker_hf_config.audio_start_token_id, + } + + self.default_tts_text_spk_type = list(self.tts_text_spk_token_ids.keys())[0] + + return set(["thinker_embedding.weight", "talker_embedding.weight"]) + + def _get_text_spk_token_id(self, voice_type: str) -> int: + """Get speaker token ID for voice type.""" + if voice_type not in self.tts_text_spk_token_ids: + return self.tts_text_spk_token_ids[self.default_tts_text_spk_type] + return self.tts_text_spk_token_ids[voice_type] + + def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object): + """ + Postprocess the talker hidden states. + """ + update_dict = {} + update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous() + return update_dict + + def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): + """ + Preprocess talker embeds. Noted that we set the MTP here. 
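+
+        Branch summary: inputs with more than one token (prefill) are rebuilt
+        via ``talker_preprocess_prefill`` and get a zero-initialized
+        ``code_predictor_codes`` tensor; single-token steps (decode) pack the
+        previous talker hidden state and text step into ``mtp_inputs`` for the
+        MTP code predictor.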
+ """ + # Ensure we have base embeddings when only ids are provided + if input_embeds is None and input_ids is not None: + input_embeds = self.talker.embed_input_ids(input_ids) + + span_len = input_ids.shape[0] + if span_len > 1: + # prefill + input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict) + code_predictor_codes = torch.zeros( + (input_embeds.shape[0], self.talker.num_code_groups), + device=self._module_device(self.talker), + dtype=torch.long, + ) + update_dict["code_predictor_codes"] = code_predictor_codes + else: + last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( + input_ids, input_embeds, **info_dict + ) + update_dict["mtp_inputs"] = last_talker_hidden, text_step + + return input_ids, input_embeds, update_dict + + def talker_mtp( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, + ): + # TODO(Peiqi): not support intermediate_tensors now + input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1)) + inputs_embeds = safe_tensor_reshape(input_embeds, (-1, self.talker_config.text_config.hidden_size)) + text_step = safe_tensor_reshape(text_step, (-1, self.talker_config.text_config.hidden_size)) + last_talker_hidden = safe_tensor_reshape( + last_talker_hidden, (-1, 1, self.talker_config.text_config.hidden_size) + ) + # for profiling + if inputs_embeds.shape[-1] == 2048: + inputs_embeds = self.text_projection(inputs_embeds) + code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( + input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden + ) + inputs_embeds = summed_embeddings.clone() + inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size) + return inputs_embeds, code_predictor_codes.squeeze(-1) + + def _get_tts_embed(self, thinker_embed, tts_bos_thinker, tts_eos_thinker, tts_pad_thinker): + """Project thinker-side TTS embeddings into talker text space.""" + module_device = self._module_device(self.talker) + + def _ensure_1x1(x: torch.Tensor) -> torch.Tensor: + if x.ndim == 3: + return x[0, -1:, :] + if x.ndim == 2: + return x[-1] + return x.view(1, 1, -1) + + def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor: + if isinstance(x_opt, torch.Tensor) and x_opt.numel() > 0: + xin = _ensure_1x1(x_opt).to(module_device) + else: + xin = torch.zeros( + (1, thinker_embed.shape[-1]), + device=module_device, + dtype=thinker_embed.dtype, + ) + return self.talker.text_projection(xin).to(module_device) + + self.tts_bos_embed = _proj_from_thinker(tts_bos_thinker) + self.tts_eos_embed = _proj_from_thinker(tts_eos_thinker) + self.tts_pad_embed = _proj_from_thinker(tts_pad_thinker) + return self.tts_bos_embed, self.tts_eos_embed, self.tts_pad_embed + + def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): + # Containers to return per-request updates (e.g., code_predictor_hidden_per_request) + update_dict: dict[str, dict] = {} + # TODO(Peiqi): add voice_type support + voice_type = self.voice_type + + # Read thinker outputs for prefill + thinker_sequence_embeds = info_dict.get("thinker_embeddings").to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) # Tensor [P,H] + thinker_hidden_states = info_dict.get("thinker_hidden_states").to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) # Tensor [K,H] + thinker_sequences = ( + 
info_dict.get("thinker_sequences") + if info_dict.get("thinker_sequences") is None + else torch.as_tensor(info_dict.get("thinker_sequences"), device=self._module_device(self.talker)) + ) + thinker_chatml_ids = ( + info_dict.get("thinker_input_ids") + if info_dict.get("thinker_input_ids") is None + else torch.as_tensor(info_dict.get("thinker_input_ids"), device=self._module_device(self.talker)) + ) + + tts_bos_thinker = info_dict.get("tts_bos_embed").to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) + tts_eos_thinker = info_dict.get("tts_eos_embed").to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) + tts_pad_thinker = info_dict.get("tts_pad_embed").to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) + + if thinker_sequence_embeds is None or thinker_hidden_states is None: + raise ValueError( + "additional_information_by_req_id must include " + "'thinker_embeddings' and 'thinker_hidden_states' for talker prefill." + ) + + # Normalize to tensors + if not isinstance(thinker_sequence_embeds, torch.Tensor): + thinker_sequence_embeds = torch.as_tensor(thinker_sequence_embeds, device=self._module_device(self.talker)) + if not isinstance(thinker_hidden_states, torch.Tensor): + thinker_hidden_states = torch.as_tensor(thinker_hidden_states, device=self._module_device(self.talker)) + + if isinstance(thinker_chatml_ids, torch.Tensor) or isinstance(thinker_chatml_ids, list): + ids_chatml = ( + thinker_chatml_ids + if isinstance(thinker_chatml_ids, torch.Tensor) + else torch.as_tensor(thinker_chatml_ids, device=self._module_device(self.talker)) + ) + if ids_chatml.ndim == 1: + ids_chatml = ids_chatml.unsqueeze(0) + else: + # Fallback: create dummy ids if not provided + ids_chatml = torch.zeros( + (1, thinker_sequence_embeds.shape[1]), + dtype=torch.long, + device=self._module_device(self.talker), + ) + thinker_sequences = ids_chatml + + speaker_id = self._get_text_spk_token_id(voice_type) + req_input_ids, req_embeds, trailing_text_hidden = self._thinker_to_talker_prefill( + thinker_embed=thinker_sequence_embeds.to(self._module_device(self.talker)), + thinker_hidden=thinker_hidden_states.to(self._module_device(self.talker)), + multimodal_mask=None, + input_ids=ids_chatml.to(self._module_device(self.talker)), + thinker_result_ids=thinker_sequences.to(self._module_device(self.talker)), + speaker_id=speaker_id, + tts_bos_thinker=tts_bos_thinker, + tts_eos_thinker=tts_eos_thinker, + tts_pad_thinker=tts_pad_thinker, + ) + + # Queue trailing_text_hidden for decode (drop first for next steps), + try: + if isinstance(trailing_text_hidden, torch.Tensor) and trailing_text_hidden.numel() > 0: + if trailing_text_hidden.ndim == 2: + rem_tail = trailing_text_hidden + elif trailing_text_hidden.ndim == 1: + rem_tail = torch.zeros( + 0, + trailing_text_hidden.shape[0], + dtype=trailing_text_hidden.dtype, + device=trailing_text_hidden.device, + ) + else: + # compatible with old shape [1,S,D] + rem_tail = trailing_text_hidden.squeeze(0) + if rem_tail.shape[0] > 0: + update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous() + # Also persist projected tts_pad for decode fallback if needed + if isinstance(tts_pad_thinker, torch.Tensor): + pad_in = tts_pad_thinker + if pad_in.ndim == 2: + pad_in = pad_in.unsqueeze(0) + if pad_in.ndim == 1: + pad_in = pad_in.view(1, 1, -1) + pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker))) + update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous() + 
except Exception: + pass + + return req_input_ids, req_embeds, update_dict + + def _thinker_to_talker_prefill( + self, + thinker_embed: torch.Tensor, + thinker_hidden: torch.Tensor, + multimodal_mask: torch.Tensor | None, + input_ids: torch.Tensor, + thinker_result_ids: torch.Tensor, + speaker_id, + tts_bos_thinker: torch.Tensor | None = None, + tts_eos_thinker: torch.Tensor | None = None, + tts_pad_thinker: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """ + Project thinker outputs to talker inputs during prefill stage. + + Returns: + (input_ids, input_embeds) for talker + """ + im_start_indexes = torch.cat( + ( + torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(), + torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype), + ), + dim=-1, + ) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here. + multimodal_mask = ( + (thinker_result_ids == self.thinker_config.audio_token_id) | + (thinker_result_ids == self.thinker_config.image_token_id) | + (thinker_result_ids == self.thinker_config.video_token_id) + ).to(input_ids.device) # [t] # fmt: skip + + tts_bos_embed, tts_eos_embed, tts_pad_embed = self._get_tts_embed( + thinker_embed, tts_bos_thinker, tts_eos_thinker, tts_pad_thinker + ) + + talker_input_embeds = [] # [1 t d] + talker_input_ids = [] + trailing_text_hidden_all: torch.Tensor | None = None + # For every chatml parts + for i in range(len(im_start_indexes) - 1): + im_start_index = im_start_indexes[i].item() + segment_end_index = im_start_indexes[i + 1].item() + role_token = input_ids[0][im_start_index + 1] + # Talker should ignore thinker system prompt + if (role_token == self.config.system_token_id).item(): + continue + # Talker takes word embeddings for tokens and hidden state from `accept_hidden_layer` for multimodal inputs + elif (role_token == self.config.user_token_id).item(): + talker_user_part = self._get_talker_user_parts( + im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed + ) + talker_input_embeds.append(talker_user_part) + talker_input_ids.append(thinker_result_ids[im_start_index:segment_end_index]) + # Take assistant output (for now) + elif (role_token == self.config.assistant_token_id).item() and i == len(im_start_indexes) - 2: + talker_assistant_embeds, talker_assistant_ids, trailing_text_hidden = self._get_talker_assistant_parts( + im_start_index, + segment_end_index, + speaker_id, + thinker_embed, + tts_pad_embed, + tts_bos_embed, + tts_eos_embed, + ) + talker_input_embeds.append(talker_assistant_embeds) + talker_input_ids.append(talker_assistant_ids) + # capture trailing text hidden for decode steps + try: + if isinstance(trailing_text_hidden, torch.Tensor): + trailing_text_hidden_all = trailing_text_hidden + except Exception: + pass + # History assistant output (ignore for now) + elif (role_token == self.config.assistant_token_id).item() and i != len(im_start_indexes) - 2: + continue + else: + raise AssertionError("Expect role id after <|im_start|> (assistant, user, system)") + talker_input_embed = torch.cat([embed.to(input_ids.device) for embed in talker_input_embeds], dim=0) + talker_input_id = torch.cat([embed.to(input_ids.device) for embed in talker_input_ids], dim=0) + + return talker_input_id, talker_input_embed, trailing_text_hidden_all + + def _thinker_decode_to_talker_decode( + self, + info_dict: dict, + device: torch.device, + update_dict, + ): + """ + Project thinker outputs to talker inputs 
during prefill stage. + Returns: + (input_ids, input_embeds) for talker + """ + thinker_embed = info_dict.get("thinker_embeddings", None) + if thinker_embed is None: + if info_dict.get("finished_flag"): + return self.tts_pad_embed.to(device) + update_dict["finished_flag"] = True + return self.tts_eos_embed.to(device) + + thinker_embed = thinker_embed.to(device) + return self.talker.text_projection(thinker_embed).to(device) + + def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): + update_dict: dict[str, dict] = {} + last_talker_hidden = None + text_step = None + try: + if self.vllm_config.model_config.async_chunk: + text_step = self._thinker_decode_to_talker_decode(info_dict, input_ids.device, update_dict) + else: + q_tail = info_dict.get("trailing_text_hidden", None) + if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0: + use_vec = q_tail[0:1, :] + new_q_tail = ( + q_tail[1:, :].detach().to("cpu").contiguous() + if q_tail.shape[0] > 1 + else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) + ) + text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype) + update_dict["trailing_text_hidden"] = new_q_tail + else: + text_step = self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) + + last_talker_hidden_tensor = info_dict.get("last_talker_hidden") + if last_talker_hidden_tensor is not None: + last_talker_hidden = last_talker_hidden_tensor.to(input_embeds.device, dtype=input_embeds.dtype) + last_talker_hidden = last_talker_hidden.reshape(*last_talker_hidden.shape[-2:]) # [1, hidden_size] + else: + last_talker_hidden = torch.zeros( + (1, self.talker_config.text_config.hidden_size), + device=input_embeds.device, + dtype=input_embeds.dtype, + ) + except Exception as e: + logger.error(f"Error in decode: {e}") + + return last_talker_hidden, text_step, update_dict + + def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed): + user_talker_part = torch.empty( + (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size), + device=thinker_hidden.device, + dtype=torch.bfloat16, + ) + + user_mm_mask = multimodal_mask[im_start_index:segment_end_index] + # Multimodal data exists + if user_mm_mask.any(): + user_thinker_hidden_mm = thinker_hidden[im_start_index:segment_end_index][user_mm_mask] + mm_hidden = self.talker.hidden_projection(user_thinker_hidden_mm).to(thinker_hidden.device) + user_talker_part[user_mm_mask] = mm_hidden + user_thinker_embed = thinker_embed[im_start_index:segment_end_index][~user_mm_mask] + user_text_hidden = self.talker.text_projection(user_thinker_embed).to(thinker_hidden.device) + user_talker_part[~user_mm_mask] = user_text_hidden + return user_talker_part + + def _get_talker_assistant_parts( + self, im_start_index, segment_end_index, speaker_id, thinker_embed, tts_pad_embed, tts_bos_embed, tts_eos_embed + ): + assistant_hidden = self.talker.text_projection(thinker_embed[im_start_index:segment_end_index]).to( + tts_pad_embed.device + ) # [t, d] + + # [3 tokens] + [4 pad] + [1 BOS] + [1 first text] = 9 tokens + assistant_text_hidden = torch.cat( + ( + assistant_hidden[:3], + tts_pad_embed.expand(4, -1), + tts_bos_embed, + assistant_hidden[3:4] + if assistant_hidden.shape[0] > 3 + else torch.zeros( + (1, assistant_hidden.shape[1]), + device=assistant_hidden.device, + dtype=assistant_hidden.dtype, + ), # First text + ), + dim=0, + ) + codec_special_tokens = torch.tensor( + [ + 
self.config.talker_config.codec_nothink_id, + self.config.talker_config.codec_think_bos_id, + self.config.talker_config.codec_think_eos_id, + speaker_id, + self.config.talker_config.codec_pad_id, + self.config.talker_config.codec_bos_id, + ], + device=tts_pad_embed.device, + dtype=torch.long, + ) + embed_input_ids = self.talker.embed_input_ids(codec_special_tokens).to( + device=tts_pad_embed.device, dtype=torch.bfloat16 + ) + assistant_codec_hidden = torch.cat( + ( + torch.zeros( + (3, self.config.talker_config.text_config.hidden_size), + device=tts_pad_embed.device, + dtype=torch.bfloat16, + ), + embed_input_ids, + ), + dim=0, + ) + + if assistant_hidden.shape[0] > 4: + trailing_text_hidden = torch.cat( + (assistant_hidden[4:], tts_eos_embed), + dim=0, + ) + else: + trailing_text_hidden = torch.zeros( + tts_eos_embed.shape, device=tts_eos_embed.device, dtype=tts_eos_embed.dtype + ) + + input_embeds = assistant_text_hidden + assistant_codec_hidden + input_ids = torch.full( + (assistant_text_hidden.shape[0],), + fill_value=self.config.tts_pad_token_id, + dtype=torch.long, + device=assistant_text_hidden.device, + ) + return input_embeds, input_ids, trailing_text_hidden + + def _talker_to_code_predictor( + self, + talker_hidden_states: torch.Tensor | None, + layer0_token_ids: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Project talker outputs to code predictor inputs. + + Returns: + (input_ids, input_embeds) for code predictor. + """ + predictor = getattr(self, "code_predictor", None) + device = ( + self._module_device(predictor) + if predictor is not None + else ( + talker_hidden_states.device + if isinstance(talker_hidden_states, torch.Tensor) + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + ) + + if not isinstance(talker_hidden_states, torch.Tensor): + raise ValueError("Talker hidden states must be provided for the code predictor stage.") + + inputs_embeds = talker_hidden_states.to(device=device, dtype=torch.bfloat16) + if inputs_embeds.ndim == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + + if not isinstance(layer0_token_ids, torch.Tensor): + raise ValueError("Layer-0 codec token ids must accompany talker hidden states.") + input_ids = layer0_token_ids.to(device=device, dtype=torch.long) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + return input_ids, inputs_embeds + + # ==================== Logits and Sampling ==================== + + def _warn_talker_sampling_temperature(self, sampling_metadata: SamplingMetadata): + warning_parts = [] + if sampling_metadata.temperature is None: + warning_parts.append( + "Temperature is set to None, as all requests are greedy. " + "This is equivalent to setting temperature to 0.0." + "Please consider setting a higher temperature i.e. 0.4." + ) + else: + warning_parts.append( + "Temperature is set to: " + f"{sampling_metadata.temperature}, where temperature as 0.0 may " + "cause repetitive output. Please consider setting a higher " + "temperature i.e. 0.4." + ) + warning_parts.append( + "This warning will be shown only once, for the first request where " + "temperature is 0.0. Later requests will not show this warning but " + "still be affected by the temperature." 
+ ) + warning_info = "\n".join(warning_parts) + logger.warning_once(warning_info) + + def compute_logits( + self, + hidden_states: torch.Tensor | OmniOutput, + sampling_metadata: SamplingMetadata = None, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + # Handle OmniOutput type + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + + if ( + getattr(self, "model_stage", None) == "talker" + and sampling_metadata is not None + and (sampling_metadata.temperature is None or (sampling_metadata.temperature <= 0).any()) + ): + self._warn_talker_sampling_temperature(sampling_metadata) + + # Use active model for logits computation + logits = self.model.compute_logits(hidden_states) # V, d + # Talker: suppress tokens by setting their probability to ~1e-9 (finite very small), + # implemented by assigning their logits to log(1e-9). + + if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor): + # suppress tokens by setting their probability to ~1e-9 (finite very small) + suppressed_tokens = self._get_talker_suppressed_tokens() + try: + logits_cpu = logits.cpu() + logits_cpu[:, suppressed_tokens] = -1e9 + logits = logits_cpu.to(logits.device) + except Exception as e: + print(f"Error in logits suppression: {e}") + print(f"logits.shape: {logits.shape}") + print(f"suppressed_tokens: {suppressed_tokens}") + raise e + logits[:, suppressed_tokens] = -1e9 + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + """Sample from logits.""" + return self.model.sample(logits, sampling_metadata) + + # ==================== Weight Loading ==================== + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for all components of the omni model.""" + loaded_weights = set() + thinker_weights = [] + talker_weights = [] + code2wav_weights = [] + + # Separate weights by component + for k, v in weights: + if k.startswith("thinker."): + thinker_weights.append((k, v)) + elif k.startswith("talker."): + talker_weights.append((k, v)) + elif k.startswith("code2wav."): + code2wav_weights.append((k, v)) + else: + logger.warning(f"Unknown weight prefix: {k}") + # Load thinker weights + if self.thinker and thinker_weights: + thinker_loaded = self.thinker.load_weights(thinker_weights) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") + loaded_weights.update(thinker_loaded) + + # Load talker weights + if self.talker and talker_weights: + talker_loaded = self.talker.load_weights(talker_weights) + talker_loaded = add_prefix_to_loaded_weights(talker_loaded, "talker") + loaded_weights.update(talker_loaded) + loaded_weights.update(self._init_special_tokens_embeddings()) + + # Load code2wav weights + if self.code2wav and code2wav_weights: + code2wav_loaded = self.code2wav.load_weights(code2wav_weights) + code2wav_loaded = add_prefix_to_loaded_weights(code2wav_loaded, "code2wav") + loaded_weights.update(code2wav_loaded) + + # Log summary + logger.info( + "Loaded %d weights for Qwen3OmniMoe (stage=%s)", + len(loaded_weights), + self.model_stage, + ) + + return loaded_weights diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py new file mode 100644 index 0000000000000000000000000000000000000000..7adb6f96f89662bc883604f7ea36fceb5cffebcc --- /dev/null +++ 
b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +"""Inference-only Qwen3-Omni-Moe Code2Wav model.""" + +from __future__ import annotations + +from collections.abc import Iterable + +import numpy as np +import torch +import torch.nn as nn +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeCode2WavConfig, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeCausalConvNet, + Qwen3OmniMoeCausalTransConvNet, + Qwen3OmniMoeCode2WavDecoderBlock, + Qwen3OmniMoeCode2WavTransformerModel, + Qwen3OmniMoeConvNeXtBlock, + SnakeBeta, +) +from vllm.config import VllmConfig # type: ignore +from vllm.logger import init_logger # type: ignore +from vllm.model_executor.models.utils import ( # type: ignore + AutoWeightsLoader, + WeightsMapper, +) + +logger = init_logger(__name__) + + +class Qwen3OmniMoeCode2Wav(nn.Module): + """ + Qwen3 Omni MoE Code2Wav - Converts num_quantizers-layer RVQ codec codes to audio waveform. + + Architecture: + 1. Code Embedding: Embed and average num_quantizers RVQ layers + 2. Pre-Transformer: Add temporal context via sliding-window attention + 3. Upsampling: Progressive upsampling with ConvNeXt blocks + 4. Decoder: Multi-stage upsampling + residual units → waveform + + Input: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes + Output: [batch, 1, waveform_len] - Audio waveform [-1, 1] + + Total upsampling factor: ~1280x + Example: 100 codec frames → 128,000 audio samples (8 seconds at 16kHz) + """ + + input_modalities = "audio" + + # Weight mapper + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "code2wav.pre_transformer.": "pre_transformer.", + "code2wav.code_embedding.": "code_embedding.", + "code2wav.upsample.": "upsample.", + "code2wav.decoder.": "decoder.", + "code2wav.": "", + } + ) + + def __init__( + self, + *, + vllm_config: VllmConfig | None = None, + prefix: str = "", + ): + super().__init__() + + self.config: Qwen3OmniMoeCode2WavConfig = vllm_config.model_config.hf_config + + # Calculate total upsampling factor + self.total_upsample = np.prod(self.config.upsample_rates + self.config.upsampling_ratios) + + # Pre-transformer + self.pre_transformer = Qwen3OmniMoeCode2WavTransformerModel._from_config(self.config) + + # Code embedding: Single embedding table for all RVQ layers + self.code_embedding = nn.Embedding( + self.config.codebook_size * self.config.num_quantizers, self.config.hidden_size + ) + + # Offset for each RVQ layer (layer 0: 0-1023, layer 1: 1024-2047, etc.) 
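A minimal sketch of the offset arithmetic behind the `code_offset` buffer registered just below (illustrative sizes only): code `k` of quantizer layer `i` is looked up at row `i * codebook_size + k` of the shared embedding table.

```python
import torch

codebook_size, num_quantizers = 1024, 4      # hypothetical values
offset = torch.arange(num_quantizers).view(1, -1, 1) * codebook_size   # [1, Q, 1]

codes = torch.randint(0, codebook_size, (2, num_quantizers, 5))        # [batch, Q, T]
shifted = codes + offset                     # broadcasts to [2, Q, 5]
# Layer-1 codes now land in rows 1024..2047 of the shared embedding table.
assert shifted[:, 1].min() >= codebook_size
assert shifted[:, 1].max() < 2 * codebook_size
```

In `forward`, the shifted codes are embedded through this shared table and then averaged over the quantizer dimension (`.mean(1)`).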
+ self.register_buffer( + "code_offset", + torch.arange(self.config.num_quantizers).view(1, -1, 1) * self.config.codebook_size, + persistent=False, + ) + + # Upsampling blocks (e.g., 2x, 2x) + upsample = [] + for factor in self.config.upsampling_ratios: + upsample.append( + nn.ModuleList( + [ + Qwen3OmniMoeCausalTransConvNet( + self.config.hidden_size, self.config.hidden_size, factor, factor + ), + Qwen3OmniMoeConvNeXtBlock(self.config.hidden_size), + ] + ) + ) + self.upsample = nn.ModuleList(upsample) + + # Decoder: Initial projection + progressive upsampling blocks + decoder = [Qwen3OmniMoeCausalConvNet(self.config.hidden_size, self.config.decoder_dim, kernel_size=7)] + + # Add decoder blocks (each upsamples and reduces channels) + for i in range(len(self.config.upsample_rates)): + decoder.append(Qwen3OmniMoeCode2WavDecoderBlock(self.config, i)) + + # Final projection to waveform + output_dim = self.config.decoder_dim // 2 ** len(self.config.upsample_rates) + decoder += [ + SnakeBeta(output_dim), + Qwen3OmniMoeCausalConvNet(output_dim, 1, kernel_size=7), + ] + self.decoder = nn.ModuleList(decoder) + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + """ + Convert num_quantizers-layer RVQ codes to audio waveform. + + Args: + codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codec codes + + Returns: + waveform: [batch, 1, waveform_len] - Audio waveform clipped to [-1, 1] + """ + if codes.shape[1] != self.config.num_quantizers: + raise ValueError(f"Expected {self.config.num_quantizers} layers of codes, got {codes.shape[1]}") + + # Stage 1: Code Embedding + # Add offset to separate layer vocabularies, then embed and average + hidden = self.code_embedding(codes + self.code_offset).mean(1) + # Shape: [batch, seq_len, hidden_size] + + # Stage 2: Pre-Transformer (add temporal context) + hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state + # Shape: [batch, seq_len, hidden_size] + + # Stage 3: Upsampling + hidden = hidden.permute(0, 2, 1) # [batch, hidden_size, seq_len] + for blocks in self.upsample: + for block in blocks: + hidden = block(hidden) + # Shape: [batch, hidden_size, seq_len * upsample_factor] + + # Stage 4: Decoder (progressive upsampling to waveform) + wav = hidden + for block in self.decoder: + wav = block(wav) + # Shape: [batch, 1, waveform_len] + + # Clamp to valid audio range + return wav.clamp(min=-1.0, max=1.0) + + def chunked_decode( + self, + codes: torch.Tensor, + chunk_size: int = 300, + left_context_size: int = 25, + ) -> torch.Tensor: + """ + Decode long sequences in chunks to avoid OOM. + + Uses overlapping chunks with left context to avoid boundary artifacts. 
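The chunk/context bookkeeping can be exercised without the model itself. The sketch below uses a stand-in decoder that just repeats each frame by a hypothetical upsample factor; because that stand-in has no memory across frames, chunked decoding reproduces the one-shot result exactly, which is what the left context approximates for the real causal decoder.

```python
import torch

total_upsample = 4            # hypothetical per-frame upsample factor

def fake_decode(frames: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real decoder: repeat each frame `total_upsample` times.
    return frames.repeat_interleave(total_upsample, dim=-1)

def chunked(frames: torch.Tensor, chunk_size: int = 3, left_context: int = 2) -> torch.Tensor:
    outs, start = [], 0
    while start < frames.shape[-1]:
        end = min(start + chunk_size, frames.shape[-1])
        ctx = left_context if start >= left_context else start
        out = fake_decode(frames[..., start - ctx : end])
        outs.append(out[..., ctx * total_upsample :])   # drop the samples owed to the context
        start = end
    return torch.cat(outs, dim=-1)

frames = torch.arange(10).view(1, 10)
assert torch.equal(chunked(frames), fake_decode(frames))   # same waveform, chunk by chunk
```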
+ + Args: + codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes + chunk_size: Number of codec frames per chunk + left_context_size: Number of overlapping frames for context + + Returns: + waveform: [batch, 1, waveform_len] - Complete waveform + """ + wavs = [] + start_index = 0 + + while start_index < codes.shape[-1]: + end_index = min(start_index + chunk_size, codes.shape[-1]) + context_size = left_context_size if start_index >= left_context_size else start_index + + # Extract chunk with left context + codes_chunk = codes[..., start_index - context_size : end_index] + + # Decode chunk + wav_chunk = self(codes_chunk) + + # Remove context from output (context_size * total_upsample samples) + wavs.append(wav_chunk[..., context_size * self.total_upsample :]) + + start_index = end_index + + return torch.cat(wavs, dim=-1) + + def chunked_decode_streaming( + self, + codes: torch.Tensor, + chunk_size: int = 25, + left_context_size: int = 25, + ) -> torch.Tensor: + """ + Decode long sequences in chunks to avoid OOM. + + Uses overlapping chunks with left context to avoid boundary artifacts. + + Args: + codes: [batch, num_quantizers, seq_len] - num_quantizers-layer RVQ codes + chunk_size: Number of codec frames per chunk + left_context_size: Number of overlapping frames for context + + Returns: + waveform: [batch, 1, waveform_len] - Complete waveform + """ + wavs = [] + end_index = codes.shape[-1] + # TODO: need to optimize algorithms, current only support + # chunk_size = left_context_size = 25 + if end_index <= chunk_size: + context_size = 0 + else: + context_size = left_context_size + # Decode chunk + wav_chunk = self(codes) + # Remove context from output (context_size * total_upsample samples) + wavs.append(wav_chunk[..., context_size * self.total_upsample :]) + return torch.cat(wavs, dim=-1) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights from HuggingFace checkpoint.""" + loader = AutoWeightsLoader( + self, + skip_prefixes=["thinker.", "talker."], # Already loaded above + ) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + # Log load summary + try: + total_bytes = 0 + for name, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, + True, + total_bytes / (1024**2), + str(device), + ) + except Exception: + logger.error("Error logging model load summary") + + return loaded diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7402890abe65654f2b8a45c6567467eb822026 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -0,0 +1,589 @@ +"""Qwen3-Omni Code Predictor with MTP (Multi-Token Prediction) support. + +This module implements the code predictor component for Qwen3-Omni talker models. + +The code predictor generates residual RVQ (Residual Vector Quantization) codes +autoregressively, predicting layers 1 to N based on layer-0 codes from the talker. 
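As background on what "residual" codes are, here is a toy residual-vector-quantization encoder with random codebooks (not the model's actual codebooks): each layer quantizes whatever the previous layers left unexplained, so later layers refine the reconstruction.

```python
import torch

torch.manual_seed(0)
dim, codebook_size, num_layers = 8, 16, 4                  # hypothetical sizes
codebooks = [torch.randn(codebook_size, dim) for _ in range(num_layers)]

def rvq_encode(x: torch.Tensor) -> tuple[list[int], torch.Tensor]:
    codes, residual = [], x.clone()
    for cb in codebooks:
        idx = torch.cdist(residual.unsqueeze(0), cb).argmin().item()  # nearest codebook entry
        codes.append(idx)
        residual = residual - cb[idx]          # next layer quantizes what is left over
    return codes, x - residual                 # reconstruction = sum of chosen entries

codes, recon = rvq_encode(torch.randn(dim))
assert len(codes) == num_layers
```

In the model, the talker emits the layer-0 code and this predictor fills in the remaining residual layers conditioned on it.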
+""" + +from collections import namedtuple +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Cache, PretrainedConfig +from transformers.generation.logits_process import ( + LogitsProcessorList, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.utils.torch_utils import direct_register_custom_op + +logger = init_logger(__name__) + +# ============================================================================ +# Code Predictor Attention Layer +# ============================================================================ + + +class Qwen3OmniCodePredictorAttention(nn.Module): + """Multi-head self-attention for code predictor with vLLM optimization.""" + + def __init__( + self, + config, + layer_idx: int, + vllm_config: VllmConfig = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + + self.num_heads = config.code_predictor_config.num_attention_heads + self.num_key_value_heads = config.code_predictor_config.num_key_value_heads + self.head_dim = getattr( + config.code_predictor_config, + "head_dim", + config.code_predictor_config.hidden_size // config.code_predictor_config.num_attention_heads, + ) + self.hidden_size = config.code_predictor_config.hidden_size + + if self.num_heads % self.num_key_value_heads != 0: + raise ValueError("num_attention_heads must be divisible by num_key_value_heads") + + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_key_value_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=True, + ) + self.o_proj = RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + disable_tp=True, + ) + self.rotary_emb = get_rope( + self.head_dim, + max_position=config.code_predictor_config.max_position_embeddings, + rope_parameters=None, + dual_chunk_attention_config=None, + ) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + + # Query/Key normalization + self.q_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.code_predictor_config.rms_norm_eps) + self.is_causal = True + self.config = config + + self.attention_backends = ["flash_attention_2", "xformers", "eager", "sdpa"] + cudagraph_mode = get_current_vllm_config().compilation_config.cudagraph_mode + if "flash_attention_2" in ALL_ATTENTION_FUNCTIONS and cudagraph_mode.has_full_cudagraphs(): + logger.warning( + 
f"CUDAGraphMode.{cudagraph_mode.name} is currently not supported " + f"with flash attention for Qwen3-Omni talker MTP." + f"removing flash attention from attention_backends" + ) + self.attention_backends.remove("flash_attention_2") + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool = False, + position_ids: torch.LongTensor | None = None, + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape for attention + q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) + k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + v = v.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + # Apply normalization + q = self.q_norm(q).contiguous() + k = self.k_norm(k).contiguous() + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + # Apply RoPE + q, k = self.rotary_emb(position_ids, q, k) + + # Reshape for attention + q = q.reshape(bsz, seq_len, self.num_heads, self.head_dim) + k = k.reshape(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + v_heads = v.transpose(1, 2).contiguous() + q_heads = q.transpose(1, 2).contiguous() + k_heads = k.transpose(1, 2).contiguous() + + if past_key_values is not None: + sin, cos = self.rotary_emb.get_cos_sin(seq_len) + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k_heads, v_heads = past_key_values.update(k_heads, v_heads, self.layer_idx, cache_kwargs) + + # Try attention backends in order of preference, with runtime error handling + # This handles cases where the backend is registered but not actually available + attn_output = None + last_error = None + + for backend_name in self.attention_backends: + if backend_name not in ALL_ATTENTION_FUNCTIONS: + continue + + try: + attention_interface = ALL_ATTENTION_FUNCTIONS[backend_name] + attn_output, _ = attention_interface( + self, + q_heads, + k_heads, + v_heads, + None, + dropout=0.0 if not self.training else getattr(self, "attention_dropout", 0.0), + scaling=self.head_dim**-0.5, + sliding_window=None, + use_cache=use_cache, + position_ids=position_ids[:seq_len].unsqueeze(0), + output_hidden_states=True, + output_attentions=False, + ) + break + except (ValueError, ImportError, RuntimeError, AttributeError) as e: + # Store error and try next backend + last_error = e + continue + + if attn_output is None: + raise RuntimeError( + f"All attention backends failed. Last error: {last_error}. " + "Please install flash-attn, or ensure PyTorch's scaled_dot_product_attention is available." 
+ ) + attn_output = attn_output.reshape(*(hidden_states.shape[:-1]), -1).contiguous() + + attn_output, _ = self.o_proj(attn_output) + return attn_output + + +# ============================================================================ +# Code Predictor MLP Layer +# ============================================================================ + + +class Qwen3OmniCodePredictorMLP(nn.Module): + """Feed-forward network for code predictor with fused gate/up projection.""" + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.code_predictor_config.hidden_size + intermediate_size = config.code_predictor_config.intermediate_size + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size, intermediate_size], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + disable_tp=True, + ) + + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + gate, up = gate_up.chunk(2, dim=-1) + down, _ = self.down_proj(F.silu(gate) * up) + return down + + +# ============================================================================ +# MTP Layer (Multi-Token Prediction Layer) +# ============================================================================ + + +class Qwen3OmniCodePredictorMTPLayer(nn.Module): + """MTP layer for speculative decoding - predicts next residual code layer.""" + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + layer_idx: int, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + + self.self_attn = Qwen3OmniCodePredictorAttention( + config, + layer_idx, + vllm_config=type( + "VllmConfig", + (), + {"cache_config": cache_config, "quant_config": quant_config, "model_config": model_config}, + )(), + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Qwen3OmniCodePredictorMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm( + config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps + ) + self.post_attention_layernorm = RMSNorm( + config.code_predictor_config.hidden_size, eps=config.code_predictor_config.rms_norm_eps + ) + + def mtp_block( + self, + hidden_states: torch.Tensor, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool = False, + position_ids: torch.LongTensor | None = None, + ) -> torch.Tensor: + # Self-attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, past_key_values, cache_position, use_cache, position_ids) + hidden_states = residual + hidden_states + + # MLP with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3OmniCodePredictorBaseModel(nn.Module): + """ + Base model for code predictor - matches HF Qwen3OmniMoeTalkerCodePredictorModel 
structure. + + This is a simple transformer that processes inputs_embeds and outputs hidden states. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config.code_predictor_config + + self.config = config + self.vocab_size = config.vocab_size + self.num_code_groups = config.num_code_groups + + # Codec embeddings (for layers 1-num_code_groups-1) + self.codec_embedding = nn.ModuleList( + [ + VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + for _ in range(config.num_code_groups - 1) + ] + ) + + # Decoder layers + self.layers = nn.ModuleList( + [ + Qwen3OmniCodePredictorMTPLayer( + vllm_config.model_config.hf_config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + layer_idx=idx, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(config.num_hidden_layers) + ] + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Any | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + **kwargs: Any, + ) -> Any: + """ + Forward pass matching HF structure. + + Args: + inputs_embeds: [batch, seq_len, hidden_size] + position_ids: Optional position IDs tensor + past_key_values: Optional cached key-value pairs + use_cache: Whether to use cache + cache_position: Optional cache position tensor + **kwargs: Additional keyword arguments + + Returns: + Named tuple with .last_hidden_state and .past_key_values attributes + """ + batch_size, seq_len, _ = inputs_embeds.shape + # Forward through decoder layers + hidden_states = inputs_embeds + + for layer in self.layers: + hidden_states = layer.mtp_block(hidden_states, past_key_values, cache_position, use_cache, position_ids) + + # Final norm + hidden_states = self.norm(hidden_states) + + # Return in HF-compatible format + Output = namedtuple("Output", ["last_hidden_state", "past_key_values"]) + return Output(last_hidden_state=hidden_states, past_key_values=None) # [batch, num_code_groups-1, hidden_size] + + def get_input_embeddings(self): + """Return codec embeddings for HF compatibility.""" + return self.codec_embedding + + +def code_predictor_sample( + logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + logits = self.logits_processors(None, logits[:, -1]) + probs = F.softmax(logits, dim=-1) + code = torch.multinomial(probs.squeeze(1), num_samples=1) # [batch, 1] + return code + + +def code_predictor_sample_fake( + logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty((logits.shape[0], 1), dtype=torch.int64, device=logits.device) + + +direct_register_custom_op( + op_name="qwen3_omni_code_predictor_sample", + op_func=code_predictor_sample, + fake_impl=code_predictor_sample_fake, +) + + +@support_torch_compile +class Qwen3OmniMoeTalkerCodePredictor(nn.Module): + """ + Code predictor wrapper matching HF structure. 
+ + Structure: + - self.model: Qwen3OmniCodePredictorBaseModel (transformer) + - self.lm_head: ModuleList of output heads + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + talker_code_predictor_config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.prefix = prefix + + self.config = talker_code_predictor_config + self.vocab_size = self.config.code_predictor_config.vocab_size + self.num_code_groups = self.config.code_predictor_config.num_code_groups + + # Base transformer model (matches HF structure) + self.model = Qwen3OmniCodePredictorBaseModel(vllm_config=vllm_config, prefix=prefix) + + # Output heads for each residual layer (1-num_layers-1) + self.lm_head = nn.ModuleList( + [ + nn.Linear( + self.config.code_predictor_config.hidden_size, + self.config.code_predictor_config.vocab_size, + bias=False, + ) + for _ in range(self.num_code_groups - 1) + ] + ) + self.logits_processors = LogitsProcessorList( + [ + TopKLogitsWarper(top_k=50), + TopPLogitsWarper(top_p=0.8), + ] + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + + def forward( + self, + layer0_code: torch.Tensor, + layer0_embed: torch.Tensor, + last_talker_hidden: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for code predictor. + + Args: + layer0_code: + Code index for code-group (layer) 0. + Shape: [batch_size, 1], dtype typically int64. + + last_talker_hidden: + + Shape: [batch_size, hidden_size]. + + Returns: + pos_all_layers: + Predicted codes for all code groups, including `layer0_code`. + Shape: [batch_size, num_code_groups, 1]. + + current_input: + The final input embedding sequence after appending embeddings of all + predicted codes (one token per predicted layer). + Shape: [batch_size, num_code_groups + 2, hidden_size]. 
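The returned shapes follow from the loop structure: the input sequence starts with two positions (the talker hidden state and the layer-0 embedding) and grows by one embedded code per predicted residual layer, so it ends up with 2 + (num_code_groups - 1) positions. A shape-only sketch with dummy tensors (hypothetical sizes, no real model or sampler):

```python
import torch

batch, hidden, num_code_groups = 2, 16, 4            # hypothetical sizes
current = torch.randn(batch, 2, hidden)              # [last_talker_hidden, layer0_embed]
codes = [torch.zeros(batch, 1, dtype=torch.long)]    # layer-0 codes from the talker

for _ in range(num_code_groups - 1):
    # Stand-in for model forward + lm_head + sampling: one new residual code per step.
    new_code = torch.randint(0, 10, (batch, 1))
    codes.append(new_code)
    new_embed = torch.randn(batch, 1, hidden)         # embedding of the new code
    current = torch.cat([current, new_embed], dim=1)

all_layers = torch.stack(codes, dim=1)
assert all_layers.shape == (batch, num_code_groups, 1)
assert current.shape[1] == 2 + (num_code_groups - 1)
```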
+ """ + pos_codes = [layer0_code] # Start with layer 0: [batch, 1] + try: + current_input = torch.cat([last_talker_hidden, layer0_embed], dim=1) # [batch, 2, hidden_size] + except Exception as e: + print(f"Error in current_input: {e}") + print(f"last_talker_hidden shape: {last_talker_hidden.shape}") + print(f"prev_embed shape: {layer0_embed.shape}") + raise e + batch_size = current_input.shape[0] + + # Predict all residual layers (layers 1 to num_code_groups-1) autoregressively + for layer_idx in range(self.num_code_groups - 1): + seq_len = layer_idx + 2 + # Compute position_ids dynamically to avoid torch.compile specializing batch_size + position_ids = torch.arange(seq_len, device=current_input.device, dtype=torch.int64).repeat(batch_size) + # Forward through code_predictor model + outputs = self.model( + inputs_embeds=current_input, + attention_mask=None, + position_ids=position_ids, + past_key_values=None, + use_cache=False, + cache_position=None, + ) + hidden_state = outputs.last_hidden_state # [batch, 2, hidden_size] + + # Use the corresponding lm_head for this layer + logits = self.lm_head[layer_idx](hidden_state[:, -1:, :]) + code = torch.ops.vllm.qwen3_omni_code_predictor_sample(logits, self.layer_name) + pos_codes.append(code) + # Update prev_embed for next layer (if not last layer) + # layer_idx=0 predicts layer 1, embed with codec_embedding[1] + new_embed = self.model.codec_embedding[layer_idx](code) # [batch, 1, hidden_size] + current_input = torch.cat([current_input, new_embed], dim=1) # [batch, 3~n, hidden_size] + pos_all_layers = torch.stack(pos_codes, dim=1) # [batch, num_code_groups, 1] + return pos_all_layers, current_input + + def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights with mapping for fused QKV and gate_up projections. + + Maps original HF weights (q_proj, k_proj, v_proj, gate_proj, up_proj) + to fused vLLM weights (qkv_proj, gate_up_proj). 
+ """ + # Mapping for fused projections + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Skip rotary embeddings + if "rotary_emb.inv_freq" in name: + continue + + # Handle stacked/fused parameters + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip if parameter doesn't exist (e.g., bias) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Non-stacked parameters - use default loading + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + weight_loader(param, loaded_weight) + else: + param.data.copy_(loaded_weight) + + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py new file mode 100644 index 0000000000000000000000000000000000000000..946694ed2294d107c3f7c7fce4c016fed5f9811c --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -0,0 +1,733 @@ +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeTalkerConfig, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder, +) +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsPP, +) +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniThinkerDummyInputsBuilder, +) +from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock +from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_code_predictor_mtp import ( + Qwen3OmniMoeTalkerCodePredictor, +) +from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import ( + Qwen3MoeLLMForCausalLM, + Qwen3OmniMoeConditionalGenerationMixin, + Qwen3OmniMoeThinkerMultiModalProcessor, + Qwen3OmniMoeThinkerProcessingInfo, +) + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + + +logger = 
init_logger(__name__) + +Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeTalkerForConditionalGeneration( + nn.Module, + # SupportsMultiModal, + SupportsPP, + Qwen3OmniMoeConditionalGenerationMixin, +): + """ + Qwen3 Omni MoE Talker - Converts text to audio codec codes. + + The talker is the second stage of Qwen3 Omni MoE's TTS pipeline: + 1. Thinker: Generates text response + hidden states + 2. Talker: Converts those to 8-layer audio codec codes + 3. Code2Wav: Converts codes to waveform + + ## Key Components: + - text_projection: Projects thinker text embeddings → talker dimension + - hidden_projection: Projects thinker hidden states → talker dimension + - language_model: Main MoE transformer (generates layer 0) + - codec_head: Projects to codec vocabulary (layer 0 logits) + - code_predictor: Small transformer for layers 1-num_layers-1 + """ + + logger = init_logger(__name__) + + # Weight mapping from HuggingFace to vLLM naming convention + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Main MoE transformer model + "talker.model.": "language_model.model.", + # Codec head remains separate (outputs audio codes, not text) + "talker.codec_head.": "codec_head.", + # Code predictor: Now matches HF structure exactly (has .model sub-module) + # e.g., "talker.code_predictor.model.codec_embedding.0" → "code_predictor.model.codec_embedding.0" + "talker.code_predictor.": "code_predictor.", + # Projection layers + "talker.text_projection.": "text_projection.", + "talker.hidden_projection.": "hidden_projection.", + # Fallback: strip talker prefix + "talker.": "", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + talker_config: Qwen3OmniMoeTalkerConfig = vllm_config.model_config.hf_config + talker_config.text_config.rope_parameters = talker_config.text_config.rope_scaling + talker_config.text_config.rope_parameters["rope_theta"] = talker_config.text_config.rope_theta + self.quant_config = vllm_config.quant_config + self.prefix = prefix + self.vllm_config = vllm_config + self.config = talker_config + self.vocab_size = talker_config.text_config.vocab_size + self.router_aux_loss_coef = talker_config.text_config.router_aux_loss_coef + self.num_experts = talker_config.text_config.num_experts + self.num_experts_per_tok = talker_config.text_config.num_experts_per_tok + # thinker projection components for talker + self.text_projection = Qwen3OmniMoeTalkerResizeMLP(self.config) + self.hidden_projection = Qwen3OmniMoeTalkerResizeMLP(self.config) + self.codec_head = nn.Linear(self.config.text_config.hidden_size, self.config.text_config.vocab_size, bias=False) + + self.rope_deltas = None + self.spatial_merge_size = self.config.spatial_merge_size + + self.vocab_size = self.config.code_predictor_config.vocab_size + self.num_code_groups = self.config.code_predictor_config.num_code_groups + + self.language_model = Qwen3OmniMoeModel( + vllm_config=vllm_config, + talker_config=self.config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.code_predictor = Qwen3OmniMoeTalkerCodePredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "code_predictor") + ) + max_batch_size = max( + vllm_config.scheduler_config.max_num_seqs, vllm_config.compilation_config.max_cudagraph_capture_size + ) + 
self.layer0_embed_buffer = torch.zeros( + (max_batch_size, 1, self.config.text_config.hidden_size), + dtype=vllm_config.model_config.dtype, + ) + + def code_predictor_forward( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + *, + temperature: float = 1.0, + top_k: int = 50, # Match transformers default + top_p: float = 0.8, # Match transformers default + generation_steps: int | None = None, + last_talker_hidden: torch.Tensor | None = None, + **_: object, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate full RVQ codec codes for the provided sequence. + + The code predictor consumes the layer-0 codec codes produced by the talker + alongside the talker's hidden states, and autoregressively predicts the remaining + residual layers (to num_codec_groups). + + Returns: + tuple containing: + - residual_codes: A tensor of shape [batch, num_code_groups, seq_len] containing + the complete set of codec codes + - summed_embeddings: A tensor of shape [batch, seq_len, hidden_size] + Sum of all layer embeddings at each position (like Transformers) + """ + if input_ids is None: + raise ValueError("`input_ids` containing layer-0 codec codes must be provided.") + if inputs_embeds is None: + raise ValueError("`inputs_embeds` containing talker hidden states must be provided.") + + if inputs_embeds.ndim == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + # Ensure the tensors are contiguous for the autoregressive sampling loop + inputs_embeds = inputs_embeds.contiguous() + input_ids = input_ids.contiguous() + + # Generate full codec codes using MTP + # This will be the parallel prediction implementation + batch_size, seq_len = input_ids.shape + + # For now, use sequential generation (TODO: implement parallel) + # Result will be [batch, num_code_groups, seq_len] + # - all_codes_per_position will collect [batch, num_code_groups, 1] for each position + all_codes_per_position = [] + middle_hidden_states = [] # Collect hidden states for each position + + # Generate residual layers for each position + for pos in range(seq_len): + layer0_code = input_ids[:, pos : pos + 1] # [batch, 1] + + # Initial input: [last_talker_hidden, layer0_embed] + layer0_embed = self.embed_input_ids(layer0_code) + self.layer0_embed_buffer[:batch_size].copy_(layer0_embed) + pos_all_layers, current_input = self.code_predictor( + layer0_code, self.layer0_embed_buffer[:batch_size], last_talker_hidden + ) + + # Stack all layers for this position: [batch, num_code_groups, 1] + all_codes_per_position.append(pos_all_layers) + middle_hidden_states.append(current_input[:, 2:-1, :]) + + # Concatenate across positions: [batch, num_code_groups, seq_len] + result_codes = torch.cat(all_codes_per_position, dim=2) + + # Build summed embeddings for each position (like Transformers) + # This combines layer-0 embed, mid layers hidden states, and last layer embed + all_summed_embeddings = [] + + for pos in range(seq_len): + # Layer 0 embedding + layer0_code = result_codes[:, 0, pos : pos + 1] # [batch, 1] + layer0_embed = self.embed_input_ids(layer0_code) # [batch, 1, hidden_size] + + # mid layers hidden states (from CodePredictor) + mid_residual_hiddens = middle_hidden_states[pos] # [batch, num_code_groups-2, hidden_size] + mid_list = list(mid_residual_hiddens.split(1, dim=1)) + + # last layer embedding + last_layer_code = result_codes[:, -1, pos : pos + 1] # [batch, 1] + last_residual_hidden = 
self.code_predictor.model.codec_embedding[-1](last_layer_code)
+
+            # Concatenate all layers: [batch, num_code_groups, hidden_size]
+            pos_codec_hiddens = torch.cat(
+                [layer0_embed] + mid_list + [last_residual_hidden],
+                dim=1,
+            )
+
+            # Sum across layers: [batch, 1, hidden_size] (like Transformers)
+            pos_summed = pos_codec_hiddens.sum(dim=1, keepdim=True)
+            all_summed_embeddings.append(pos_summed)
+
+        # Concatenate across positions: [batch, seq_len, hidden_size]
+        summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1)
+
+        return result_codes, summed_embeddings
+
+    def init_multi_modal(self, thinker_config: Any) -> None:
+        """
+        Initialize multimodal encoder modules from the thinker config.
+
+        The thinker remains the module that processes raw multimodal inputs;
+        the talker only converts text and hidden states into audio codec codes.
+        The audio_tower and visual modules created here mirror the thinker's
+        encoders so that the talker can parse and (dummy-)embed multimodal
+        inputs, e.g. during the profile run (see `embed_multimodal`).
+
+        Args:
+            thinker_config: Configuration from the thinker model, providing the
+                audio and vision encoder configs.
+        """
+        self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
+        self.visual = Qwen3Omni_VisionTransformer(
+            vision_config=thinker_config.vision_config,
+            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
+            quant_config=self.quant_config,
+            prefix=maybe_prefix(self.prefix, "visual"),
+            # attn_backend_override=attn_backend_override,
+        )
+
+    def project_thinker_outputs(
+        self,
+        thinker_embeds: torch.Tensor | None = None,
+        thinker_hidden_states: torch.Tensor | None = None,
+        is_multimodal_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Project thinker outputs to talker's hidden dimension.
+ + The talker has a different hidden size than the thinker, so we need + to project the inputs appropriately: + - Text embeddings (from thinker's embedding layer) → text_projection + - Hidden states (from thinker's last layer, for multimodal) → hidden_projection + + Args: + thinker_embeds: Text embeddings from thinker [batch, seq, thinker_hidden] + thinker_hidden_states: Hidden states from thinker's last layer [batch, seq, thinker_hidden] + is_multimodal_mask: Boolean mask indicating multimodal positions [batch, seq] + + Returns: + projected_embeds: [batch, seq, talker_hidden] + """ + if thinker_embeds is None and thinker_hidden_states is None: + raise ValueError("Either thinker_embeds or thinker_hidden_states must be provided") + + # If only embeddings provided, project all as text + if thinker_hidden_states is None or is_multimodal_mask is None: + return self.text_projection(thinker_embeds) + + # If only hidden states provided, project all as hidden + if thinker_embeds is None: + return self.hidden_projection(thinker_hidden_states) + + # Mixed case: use mask to decide which projection + batch_size, seq_len, _ = thinker_embeds.shape + output = torch.empty( + (batch_size, seq_len, self.config.text_config.hidden_size), + device=thinker_embeds.device, + dtype=thinker_embeds.dtype, + ) + + # Project multimodal regions using hidden states + if is_multimodal_mask.any(): + mm_hidden = thinker_hidden_states[is_multimodal_mask] + projected_mm = self.hidden_projection(mm_hidden) + output[is_multimodal_mask] = projected_mm + + # Project text regions using embeddings + if (~is_multimodal_mask).any(): + text_embeds = thinker_embeds[~is_multimodal_mask] + projected_text = self.text_projection(text_embeds) + output[~is_multimodal_mask] = projected_text + + return output + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + """Forward pass through the talker model.""" + talker_hidden_states, _ = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return talker_hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + """Compute logits for audio codec codes (layer 0 of RVQ). + + This projects the hidden states to the codec vocabulary space. + For full audio generation, layers except 0 would be predicted by + the code_predictor after sampling. + """ + logits = self.codec_head(hidden_states) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + """Create empty intermediate tensors for pipeline parallelism.""" + return self.language_model.make_empty_intermediate_tensors(batch_size, dtype, device) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+        for input_key in kwargs:
+            if input_key in ("pixel_values", "image_embeds") and "image" not in mm_input_by_modality:
+                mm_input_by_modality["image"] = self._parse_and_validate_image_input(**kwargs)
+            if input_key in ("pixel_values_videos", "video_embeds") and "video" not in mm_input_by_modality:
+                mm_input_by_modality["video"] = self._parse_and_validate_video_input(**kwargs)
+            if input_key in ("input_audio_features",) and "audio" not in mm_input_by_modality:
+                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(**kwargs)
+        return mm_input_by_modality
+
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
+        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not mm_input_by_modality:
+            return []
+
+        logger.warning(
+            "\n\n\n"
+            "THIS FUNCTION RETURNS DUMMY MULTIMODAL EMBEDDINGS FOR PROFILE RUN, "
+            "SHOULD NOT BE CALLED IN INFERENCE."
+            "\n\n\n"
+        )
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image, video or audio).
+        dummy_multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        # TODO: apply the projection for all multimodal embeddings
+        for modality in mm_input_by_modality:
+            multimodal_input = mm_input_by_modality[modality]
+            if modality == "image":
+                image_embeddings = self._process_image_input(multimodal_input)
+                dummy_image_embeddings = ()
+                for image_embed in image_embeddings:
+                    dummy_image_embeddings += (
+                        torch.zeros(
+                            image_embed.shape[0],
+                            self.config.text_config.hidden_size,
+                            device=image_embed.device,
+                            dtype=torch.bfloat16,
+                        ),
+                    )
+                dummy_multimodal_embeddings += tuple(dummy_image_embeddings)
+            if modality == "video":
+                video_embeddings = self._process_video_input(multimodal_input)
+                dummy_video_embeddings = ()
+                for video_embed in video_embeddings:
+                    dummy_video_embeddings += (
+                        torch.zeros(
+                            video_embed.shape[0],
+                            self.config.text_config.hidden_size,
+                            device=video_embed.device,
+                            dtype=torch.bfloat16,
+                        ),
+                    )
+                dummy_multimodal_embeddings += tuple(dummy_video_embeddings)
+            if modality == "audio":
+                audio_embeddings = self._process_audio_input(multimodal_input)
+                dummy_audio_embeddings = ()
+                for audio_embed in audio_embeddings:
+                    dummy_audio_embeddings += (
+                        torch.zeros(
+                            audio_embed.shape[0],
+                            self.config.text_config.hidden_size,
+                            device=audio_embed.device,
+                            dtype=torch.bfloat16,
+                        ),
+                    )
+                dummy_multimodal_embeddings += tuple(dummy_audio_embeddings)
+        return dummy_multimodal_embeddings
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
+        is_multimodal: bool = False,
+    ):
+        """Embed the given codec token IDs."""
+        return self.language_model.embed_input_ids(input_ids)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights for the talker model.
+
+        The weight mapping translates from the HuggingFace naming convention
+        to vLLM's internal structure. Code predictor weights are routed
+        to its custom loader for vocab extension support.
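+
+        Illustrative example of the prefix remapping (the weight names below are
+        examples only; see `hf_to_vllm_mapper` for the actual rules):
+
+            "talker.model.layers.0.mlp.gate.weight"
+                -> "language_model.model.layers.0.mlp.gate.weight"
+            "talker.codec_head.weight" -> "codec_head.weight"
+            "talker.text_projection.linear_fc1.weight" -> "text_projection.linear_fc1.weight"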
+ """ + loader = AutoWeightsLoader( + self, + skip_prefixes=["thinker.", "code2wav."], + # "code_predictor."], + ) + # Don't apply mapper again since we already did it + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + # Log load summary + try: + total_bytes = 0 + for name, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, + True, + total_bytes / (1024**2), + str(device), + ) + except Exception: + logger.error("Error logging model load summary") + + multi_model_weights = set() + for name, param in self.visual.named_parameters(): + multi_model_weights.add("visual." + name) + for name, param in self.audio_tower.named_parameters(): + multi_model_weights.add("audio_tower." + name) + loaded.update(multi_model_weights) + + return loaded + + +class Qwen3OmniMoeTalkerResizeMLP(nn.Module): + """ + MLP for projecting between thinker and talker hidden dimensions. + + The thinker and talker have different hidden sizes: + - Thinker: config.thinker_hidden_size (e.g., 3584) + - Talker: config.text_config.hidden_size (e.g., 2048) + + This MLP projects from thinker → talker dimension. + Two instances are used: + - text_projection: For text embeddings from thinker's embedding layer + - hidden_projection: For hidden states from thinker's last transformer layer + """ + + def __init__(self, config: Qwen3OmniMoeTalkerConfig): + super().__init__() + self.linear_fc1 = nn.Linear(config.thinker_hidden_size, config.text_config.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(config.text_config.intermediate_size, config.text_config.hidden_size, bias=True) + self.act_fn = _ACTIVATION_REGISTRY[config.text_config.hidden_act] # silu + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module): + """ + Wrapper that combines shared_expert MLP with its sigmoid gate. + + This matches the HuggingFace weight structure where: + - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight + - mlp.shared_expert_gate.weight (sibling, not child) + + The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x). + + It also exposes the underlying shared_expert interface to keep + compatibility with backends that split shared-expert computation. 
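+
+    Illustrative sketch of the computation (see `forward` below for the actual
+    handling of tuple-valued gate outputs):
+
+        gate = torch.sigmoid(shared_expert_gate(x))  # [num_tokens, 1]
+        out = gate * shared_expert(x)                # broadcast over hidden dim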
+ """ + + def __init__( + self, + shared_expert: Qwen3MoeMLP, + shared_expert_gate: nn.Linear, + ): + super().__init__() + self._shared_expert = shared_expert + self._shared_expert_gate = shared_expert_gate + + @property + def gate_up_proj(self): + return self._shared_expert.gate_up_proj + + @property + def down_proj(self): + return self._shared_expert.down_proj + + @property + def act_fn(self): + return self._shared_expert.act_fn + + def expert_gate(self, x: torch.Tensor): + gate_out = self._shared_expert_gate(x) + if isinstance(gate_out, tuple): + return gate_out + return gate_out, None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self._shared_expert(x) + gate_out = self._shared_expert_gate(x) + if isinstance(gate_out, tuple): + gate_out = gate_out[0] + gate_values = F.sigmoid(gate_out) # [batch, 1] + return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden] + + +class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module): + """ + Sparse MoE block for Qwen3 Omni MoE Talker with shared expert support. + + This block uses SharedFusedMoE to efficiently compute both routed experts + and the shared expert, potentially overlapping computation with communication. + + Weight structure matches HuggingFace: + - mlp.gate.weight (router) + - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight + - mlp.shared_expert_gate.weight + - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight + """ + + def __init__( + self, + config: Qwen3OmniMoeTalkerConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + text_config = config.text_config + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > text_config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}." + ) + + # Router gate for selecting top-k experts + self.gate = ReplicatedLinear( + text_config.hidden_size, + text_config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + # Shared expert MLP (matches HF: mlp.shared_expert.*) + if text_config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3MoeMLP( + hidden_size=text_config.hidden_size, + intermediate_size=text_config.shared_expert_intermediate_size, + hidden_act=text_config.hidden_act, + quant_config=quant_config, + reduce_results=False, # Don't reduce, we'll handle it + prefix=f"{prefix}.shared_expert", + ) + # Shared expert gate (matches HF: mlp.shared_expert_gate.weight) + # This is a sibling of shared_expert, not a child + self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False) + # Create wrapper for SharedFusedMoE + self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper( + self.shared_expert, self.shared_expert_gate + ) + else: + self.shared_expert = None + self.shared_expert_gate = None + self._shared_expert_wrapper = None + + # Fused MoE with shared expert support + self.experts = SharedFusedMoE( + shared_experts=self._shared_expert_wrapper, + num_experts=text_config.num_experts, + top_k=text_config.num_experts_per_tok, + hidden_size=text_config.hidden_size, + intermediate_size=text_config.moe_intermediate_size, + reduce_results=False, # We'll reduce manually after combining + renormalize=text_config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
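+        # Flatten to [num_tokens, hidden_dim] so the router and the fused experts
+        # operate on a 2D batch; the original shape is restored before returning.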
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # Compute router logits + router_logits, _ = self.gate(hidden_states) + + # Forward through SharedFusedMoE + # Returns (shared_out, fused_out) when shared_expert is present + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) + + # Combine shared and routed expert outputs + if self._shared_expert_wrapper is not None: + # SharedFusedMoE returns tuple: (shared_out, fused_out) + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + + # Apply tensor parallel reduction if needed + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM): + """ + Qwen3 Omni MoE Talker language model. + + This model extends Qwen3MoeLLMForCausalLM with: + - Shared expert support via SharedFusedMoE + - Codec embedding instead of text embedding + - No LM head (codec head is separate in the parent class) + """ + + def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str): + # Create a vllm_config for the talker's text model + talker_vllm_config = vllm_config.with_hf_config( + talker_config.text_config, architectures=["Qwen3MoeForCausalLM"] + ) + talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config + + super().__init__( + vllm_config=talker_vllm_config, + prefix=prefix, + ) + + self.config = talker_config + self.talker_vllm_config = talker_vllm_config + + # Remove the inherited LM head so the talker only exposes codec outputs. + if hasattr(self, "lm_head"): + del self.lm_head + + # Replace the base embed tokens with codec embedding. + if hasattr(self.model, "embed_tokens"): + del self.model.embed_tokens + + # Codec embedding for RVQ code generation + self.model.codec_embedding = nn.Embedding( + talker_config.text_config.vocab_size, + talker_config.text_config.hidden_size, + ) + + # Replace MoE blocks with shared expert versions + self._replace_moe_blocks_with_shared_expert(prefix) + + def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None: + """ + Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock + that includes shared expert support via SharedFusedMoE. 
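+
+        The stale `...mlp.experts` entry is also removed from
+        `compilation_config.static_forward_context` before building the new
+        block, so the replacement does not clash with the old registration.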
+ """ + # Get compilation config to clean up registered layer names + compilation_config = self.talker_vllm_config.compilation_config + + for layer_idx, layer in enumerate(self.model.layers): + # Check if this layer has a MoE block (has experts attribute) + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + # Remove old layer registration from static_forward_context + old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts" + if old_experts_prefix in compilation_config.static_forward_context: + del compilation_config.static_forward_context[old_experts_prefix] + + # Create new MoE block with shared expert support + layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock( + config=self.config, + quant_config=self.talker_vllm_config.quant_config, + prefix=f"{prefix}.model.layers.{layer_idx}.mlp", + ) + + def embed_input_ids( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + """Embed codec input IDs.""" + return self.model.codec_embedding(input_ids) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py new file mode 100644 index 0000000000000000000000000000000000000000..24f2aecc143568545177262fa2654f6b0d952ab5 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -0,0 +1,1116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3-Omni-Moe model (thinker part).""" + +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from packaging.version import Version +from transformers import PretrainedConfig +from transformers import __version__ as TRANSFORMERS_VERSION +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeConfig, + Qwen3OmniMoeThinkerConfig, +) +from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( + Qwen3OmniMoeProcessor, +) +from transformers.models.whisper import WhisperFeatureExtractor +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, +) +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VLProcessingInfo, +) +from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo +from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from vllm.model_executor.models.qwen3_moe import Qwen3MoeModel as _Qwen3MoeLLMModel +from vllm.model_executor.models.qwen3_omni_moe_thinker import ( + Qwen3Omni_VisionTransformer, + Qwen3OmniMoeAudioEncoder, + _get_feat_extract_output_lengths, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from vllm.model_executor.models.vision import ( + get_llm_pos_ids_for_vision, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems +from vllm.multimodal.processing.processor import ( + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, +) + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "deepstack_input_embeds": 0, + } +) +class Qwen3MoeLLMModel(_Qwen3MoeLLMModel): + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + *, + capture_layer_indices: Sequence[int] | None = None, + return_hidden_states: bool = False, + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states 
= self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + capture_set = set(capture_layer_indices) if capture_layer_indices else None + captured_hidden_states: dict[str, torch.Tensor] | None = {} if return_hidden_states else None + + for layer_idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + layer_idx = layer_idx + self.start_layer + + if captured_hidden_states is not None and capture_set is not None: + if layer_idx in capture_set: + captured_hidden_states[str(layer_idx)] = hidden_states.clone().view(-1, hidden_states.shape[-1]) + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + hidden_states, _ = self.norm(hidden_states, residual) + if captured_hidden_states is not None: + return hidden_states, captured_hidden_states + else: + return hidden_states, None + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeLLMModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + +class Qwen3OmniMoeThinkerProcessingInfo(Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config + + def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor: + processor = self.ctx.get_hf_processor( + Qwen3OmniMoeProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|audio_pad|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|image_pad|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|video_pad|>" + return processor + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None, "image": None, "video": None} + + +Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder + + +class Qwen3OmniMoeThinkerMultiModalProcessor( + Qwen2_5OmniThinkerMultiModalProcessor, +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + def pad_to_hop_length(x: np.ndarray, 
hop_length: int) -> np.ndarray: + length = x.shape[-1] + if length % hop_length != 0: + pad_length = hop_length - (length % hop_length) + x = np.pad(x, (0, pad_length), mode="constant", constant_values=0) + return x + + # NOTE: WhisperFeatureExtractor cannot handle empty list of audios + feature_extractor = self.info.get_feature_extractor() + hop_length = feature_extractor.hop_length + if audios: + # NOTE: Qwen3-Omni processor accept "audio" + # To make sure the cache works with padding=True, we pre-padded + # the audio to multiple of hop_length. + mm_data["audio"] = [ + pad_to_hop_length(audio, hop_length) + if isinstance(audio, np.ndarray) + else (pad_to_hop_length(audio[0], hop_length), audio[1]) + for audio in audios + ] + + # TODO(Isotr0py): Remove this patch after upstream fix PR + # released and Transformers version update: + # https://github.com/huggingface/transformers/pull/41473 + mm_kwargs = dict(mm_kwargs) + tok_kwargs = dict(tok_kwargs) + mm_kwargs["audio_kwargs"] = dict(mm_kwargs.get("audio_kwargs") or {}) + mm_kwargs["text_kwargs"] = dict(mm_kwargs.get("text_kwargs") or {}) + if Version(TRANSFORMERS_VERSION) < Version("4.58.0"): + # Extract audio_sample_rate before restructuring + audio_sample_rate = mm_kwargs.pop("audio_sample_rate", None) + + # move truncation to audio_kwargs level to avoid conflict + # with tok_kwargs + mm_kwargs["audio_kwargs"].setdefault("truncation", mm_kwargs.pop("truncation", False)) + mm_kwargs["text_kwargs"].setdefault("truncation", tok_kwargs.pop("truncation", False)) + + # Validate and conditionally pass audio_sample_rate + # WhisperFeatureExtractor has a fixed sampling rate, and vLLM's + # audio loader already resamples audio to the target rate. + # Only pass the value if it matches to avoid unexpected behavior. + if audio_sample_rate is not None: + expected_sr = feature_extractor.sampling_rate + if audio_sample_rate != expected_sr: + logger.warning( + "[%s] audio_sample_rate mismatch: user provided %dHz " + "but model expects %dHz. Ignoring user value. " + "vLLM's audio loader already resampled to %dHz.", + self.__class__.__name__, + audio_sample_rate, + expected_sr, + expected_sr, + ) + else: + # Sample rate matches, safe to pass + mm_kwargs["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if ( + "audio_feature_lengths" in hf_inputs + and "feature_attention_mask" in hf_inputs + and (audios := mm_data.get("audio", [])) + ): + audio_num_frames = [] + for _, audio in enumerate(audios): + audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio) + num_frame = ( + (audio_length // hop_length) if audio_length % hop_length == 0 else (audio_length // hop_length - 1) + ) + if mm_kwargs.get("truncation", False): + num_frame = min(num_frame, feature_extractor.n_samples // hop_length) + audio_num_frames.append(num_frame) + hf_inputs["feature_attention_mask"] = [torch.ones(num_frame) for num_frame in audio_num_frames] + hf_inputs["audio_feature_lengths"] = torch.tensor(audio_num_frames) + return hf_inputs + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + prompt_ids: list[int], + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen3-Omni reimplements this function to handle `use_audio_in_video`. 
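+
+        When `use_audio_in_video=True`, the audio tokens are emitted as part of
+        the interleaved video replacement, so the standalone audio prompt
+        updates are filtered out and the audio placeholders are derived from
+        the video placeholders afterwards
+        (see `_derive_audio_from_video_placeholders`).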
+ """ + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = False + if "video" in mm_kwargs: + for item in mm_kwargs["video"]: + if item and item["use_audio_in_video"].data: + use_audio_in_video = True + else: + use_audio_in_video = False + + # normal case with `use_audio_in_video=False` + if is_update_applied: + mm_placeholders = self._find_mm_placeholders( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + else: + if use_audio_in_video and "audio" in mm_prompt_updates: + filtered_updates = {k: v for k, v in mm_prompt_updates.items() if k != "audio"} + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + filtered_updates, + ) + # Derive audio placeholders from video placeholders + mm_placeholders = self._derive_audio_from_video_placeholders(mm_placeholders, mm_prompt_updates) + else: + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + + return prompt_ids, mm_placeholders + + def get_updates_use_audio_in_video( + self, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: list[int] | torch.Tensor, + video_second_per_grid_t: float, + ) -> list[int]: + shift = 0 + audio_token_id = thinker_config.audio_token_id + video_token_id = thinker_config.video_token_id + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + position_id_per_seconds = thinker_config.position_id_per_seconds + audio_token_indices = np.arange(next(iter([audio_len]))) + curr_video_grid_thw = next(iter([video_grid_thw])) + height = curr_video_grid_thw[1] // spatial_merge_size + width = curr_video_grid_thw[2] // spatial_merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width) + ).reshape(-1) + video_token_indices = ( + (video_token_indices + shift) * next(iter([video_second_per_grid_t])) * position_id_per_seconds + ) + video_data_index, audio_data_index = 0, 0 + updates = [audio_start_token_id] + while video_data_index < len(video_token_indices) and audio_data_index < len(audio_token_indices): + if video_token_indices[video_data_index] <= audio_token_indices[audio_data_index]: + updates += [video_token_id] + video_data_index += 1 + else: + updates += [audio_token_id] + audio_data_index += 1 + if video_data_index < len(video_token_indices): + updates += [video_token_id] * (len(video_token_indices) - video_data_index) + if audio_data_index < len(audio_token_indices): + updates += [audio_token_id] * (len(audio_token_indices) - audio_data_index) + updates += [audio_end_token_id] + return updates + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = 
vocab[image_token] + video_token_id = vocab[video_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + audio_output_lens = _get_feat_extract_output_lengths(feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + audio_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + nonlocal audio_item_idx + item_idx += audio_in_video_item_idx + + audio_item_idx += 1 + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short to be represented inside the model" + ) + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 2.0 + + placeholder = self.get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + return PromptUpdateDetails.select_token_id(placeholder, embed_token_id=video_token_id) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video + if use_audio_in_video + else partial(get_replacement_qwen2_vision, modality="video") + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _derive_audio_from_video_placeholders( + self, + placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + """ + Helper to derive audio placeholders from video placeholders when + use_audio_in_video=True. 
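+
+        Each derived audio placeholder reuses the matching video placeholder's
+        token span; its `is_embed` mask selects only the audio pad tokens, so
+        audio embeddings are scattered onto the interleaved audio positions.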
+ """ + if "video" not in placeholders: + return placeholders + + # Validate audio and video counts match + num_videos = len(placeholders["video"]) + num_audios = len(mm_prompt_updates.get("audio", [])) + if num_audios != num_videos: + raise ValueError( + f"use_audio_in_video requires equal number of audio and video items, got {num_audios=}, {num_videos=}" + ) + + tokenizer = self.info.get_tokenizer() + processor = self.info.get_hf_processor() + audio_token_id = tokenizer.get_vocab()[processor.audio_token] + + result_placeholders = dict(placeholders) + audio_placeholders = [] + + # Each video is paired with one audio + for video_idx, video_placeholder in enumerate(placeholders["video"]): + # Create is_embed mask selecting only audio tokens + audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id + + audio_placeholder = PlaceholderFeaturesInfo( + modality="audio", + item_idx=video_idx, + start_idx=video_placeholder.start_idx, + tokens=video_placeholder.tokens, + is_embed=audio_is_embed, + ) + audio_placeholders.append(audio_placeholder) + + result_placeholders["audio"] = audio_placeholders + return result_placeholders + + def _get_raw_input_ids( + self, + token_ids: list[int], + use_audio_in_video: bool = False, + ) -> list[int]: + tokenizer = self.info.get_tokenizer() + vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0] + vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0] + audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0] + audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0] + audio_token = tokenizer.encode("<|audio_pad|>")[0] + image_token = tokenizer.encode("<|image_pad|>")[0] + video_token = tokenizer.encode("<|video_pad|>")[0] + + result = token_ids[:] + if use_audio_in_video: + while True: + start = None + for i in range(len(result) - 1): + if result[i : i + 2] == [vision_bos_token, audio_bos_token]: + start = i + break + if start is not None: + end = None + for i in range(start + 2, len(result) - 1): + if result[i : i + 2] == [audio_eos_token, vision_eos_token]: + end = i + break + if end is not None: + result = result[:start] + [vision_bos_token, video_token, vision_eos_token] + result[end + 2 :] + else: + break + + for mm_token in [audio_token, image_token, video_token]: + compressed = [] + for x in result: + if x != mm_token or (not compressed or compressed[-1] != mm_token): + compressed.append(x) + result = compressed + + return result + + +class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin): + def _process_audio_input( + self, + audio_input: Qwen2_5OmniAudioFeatureInputs, + audio_hashes: list[str] | None = None, + cached_audio_features: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, ...]: + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + + audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_output_lengths, + ) + # OMNI: audio_tower.forward() returns hidden_states tensor directly + audio_features = audio_outputs + return audio_features.split(audio_output_lengths.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeThinkerForConditionalGeneration( + nn.Module, + 
SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + Qwen3OmniMoeConditionalGenerationMixin, +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + } + ) + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if modality.startswith("audio"): + return "<|audio_start|><|audio_pad|><|audio_end|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config # needed for torch compile forward context + thinker_config: Qwen3OmniMoeThinkerConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config, + ) + + self.visual = Qwen3Omni_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config.with_hf_config(thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]), + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + + self.use_deepstack = hasattr(thinker_config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(thinker_config.vision_config.deepstack_visual_indexes) if self.use_deepstack else 0 + ) + # register buffer for deepstack + self.deepstack_input_embeds = ( + [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + thinker_config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + if self.use_deepstack + else None + ) + self.visual_dim = thinker_config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][:num_tokens] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: + # set deepstack_input_embeds to buffer + num_tokens = deepstack_input_embeds.size(1) + if num_tokens > self.deepstack_input_embeds[0].size(0): + self.deepstack_input_embeds = [ + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) + for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_(deepstack_input_embeds[idx]) + + def _clear_deepstack_input_embeds(self, num_tokens: int) 
-> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds") and "image" not in mm_input_by_modality: + mm_input_by_modality["image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds") and "video" not in mm_input_by_modality: + mm_input_by_modality["video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features") and "audio" not in mm_input_by_modality: + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += tuple(video_embeddings) + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += tuple(audio_embeddings) + return multimodal_embeddings + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._embed_text_input_ids( + input_ids, + self.language_model.embed_input_ids, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + deepstack_input_embeds = None + # split the feat dim to obtain multi-scale visual feature + has_vision_embeddings = [ + embeddings.shape[-1] != self.config.text_config.hidden_size for embeddings in multimodal_embeddings + ] + if self.visual.deepstack_visual_indexes is not None and any(has_vision_embeddings): + multiscale_len = len(self.visual.deepstack_visual_indexes) + multimodal_embeddings_multiscale = [] + is_vision = torch.zeros_like(is_multimodal) + mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0] + mm_position_idx = 0 + for index, embeddings in enumerate(multimodal_embeddings): + num_tokens = embeddings.shape[0] + current_positions = mm_positions[mm_position_idx : mm_position_idx + num_tokens] + + # Vision embeddings + if embeddings.shape[-1] != self.config.text_config.hidden_size: + visual_dim = embeddings.shape[-1] // (multiscale_len + 1) + multi_dim = visual_dim * multiscale_len + embeddings_main, 
embeddings_multiscale = torch.split(embeddings, [visual_dim, multi_dim], dim=-1) + multimodal_embeddings[index] = embeddings_main + multimodal_embeddings_multiscale.append(embeddings_multiscale) + is_vision[current_positions] = True + + # Audio embeddings + else: + is_vision[current_positions] = False + + mm_position_idx += num_tokens + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1) + ) + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_vision, + ) + deepstack_input_embeds = ( + deepstack_input_embeds.view(inputs_embeds.shape[0], multiscale_len, visual_dim) + .permute(1, 0, 2) + .contiguous() + ) + self._set_deepstack_input_embeds(deepstack_input_embeds) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + capture_layer_indices: Sequence[int] | None = None, + return_hidden_states: bool = False, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + if self.use_deepstack and inputs_embeds is not None and get_pp_group().is_first_rank: + deepstack_input_embeds = self._get_deepstack_input_embeds(inputs_embeds.size(0)) + else: + deepstack_input_embeds = None + + hidden_states, captured_hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + capture_layer_indices=capture_layer_indices, + return_hidden_states=return_hidden_states, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + self._clear_deepstack_input_embeds(inputs_embeds.size(0)) + + return hidden_states, captured_hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "code2wav."], + ) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + return loaded_weights + + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + ) -> tuple[torch.Tensor, int]: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(image_grid_thw) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(video_grid_thw) + + input_ids = torch.tensor(input_tokens) + if input_ids is None or input_ids.ndim != 1: + raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") + + 
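+        # M-RoPE positions are accumulated as [3, chunk_len] blocks (temporal,
+        # height and width axes): text and audio spans advance all three axes in
+        # lockstep, while vision spans get per-axis indices from
+        # get_llm_pos_ids_for_vision. The blocks are concatenated to [3, seq_len].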
seq_len = input_ids.shape[0] + + if isinstance(audio_feature_lengths, list): + audio_feature_lengths = torch.tensor(audio_feature_lengths, dtype=torch.long) + + if not len(second_per_grid_ts) and len(video_grid_thw): + second_per_grid_ts = 2.0 + second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32) * second_per_grid_ts + else: + second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + + config = self.config + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + audio_token_id = config.audio_token_id + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + else: + vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype) + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + + llm_pos_ids_list: list[torch.Tensor] = [] + st = 0 + image_idx = 0 + video_idx = 0 + audio_idx = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501 + multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums # noqa: E501 + + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + + if min_ed == ed_audio_start: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append(torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) + llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append(torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list 
else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = torch.arange(grid_t) * position_id_per_seconds + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == video_token_id + and not use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append(torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = torch.arange(grid_t) * float(second_per_grids[video_idx].item()) * position_id_per_seconds + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append(torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + bos_block = torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(bos_block) + llm_pos_ids_list.append(bos_block) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) + audio_llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = torch.arange(grid_t) * float(second_per_grids[video_idx].item()) * position_id_per_seconds + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_data_index, audio_data_index = 0, 0 + while video_data_index < video_llm_pos_ids.shape[-1] and audio_data_index < audio_llm_pos_ids.shape[-1]: + if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]: + 
llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1]) + video_data_index += 1 + else: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1]) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + eos_block = torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(eos_block) + llm_pos_ids_list.append(eos_block) + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501 + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + if llm_positions.shape[1] != seq_len: + raise RuntimeError("Position ids length mismatch with input ids length") + + mrope_position_delta = llm_positions.max() + 1 - seq_len + return llm_positions, mrope_position_delta + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.merger", + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/__init__.py b/vllm_omni/model_executor/models/qwen3_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..dde690068658deb4a7c7cd8144ec5a1b845b2cc2 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py @@ -0,0 +1,523 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers.configuration_utils import PretrainedConfig, layer_type_validation +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`]. 
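For intuition, here is a minimal sketch (toy 3xN tensors, values chosen purely for illustration) of the merge step in the audio-in-video branch above: the two position-id blocks are zipped together by comparing their temporal (first-row) indices, so audio and video frames stay ordered on the shared time axis.

import torch

video_llm_pos_ids = torch.tensor([[0, 2, 4], [0, 2, 4], [0, 2, 4]])  # (3, n_video)
audio_llm_pos_ids = torch.tensor([[1, 3], [1, 3], [1, 3]])           # (3, n_audio)

merged, v, a = [], 0, 0
while v < video_llm_pos_ids.shape[-1] and a < audio_llm_pos_ids.shape[-1]:
    # take whichever block has the smaller temporal position next
    if video_llm_pos_ids[0][v] <= audio_llm_pos_ids[0][a]:
        merged.append(video_llm_pos_ids[:, v : v + 1])
        v += 1
    else:
        merged.append(audio_llm_pos_ids[:, a : a + 1])
        a += 1
merged.append(video_llm_pos_ids[:, v:])   # leftover video columns, if any
merged.append(audio_llm_pos_ids[:, a:])   # leftover audio columns, if any
print(torch.cat(merged, dim=1)[0].tolist())  # temporal row: [0, 1, 2, 3, 4]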
+ It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model + architecture. The architecture is based on the ECAPA-TDNN model. + + Args: + mel_dim (`int`, *optional*, defaults to 128): + The dimension of the input mel-spectrogram. + enc_dim (`int`, *optional*, defaults to 192): + The dimension of the final speaker embedding. + enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`): + A list of output channels for each TDNN/SERes2Net layer in the encoder. + The first channel size is for the initial TDNN layer, + the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, + and the last one for the multi-layer feature aggregation. + enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`): + A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`. + enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`): + A list of dilations for each layer in the encoder, corresponding to `enc_channels`. + enc_attention_channels (`int`, *optional*, defaults to 128): + The number of attention channels in the `AttentiveStatisticsPooling` layer. + enc_res2net_scale (`int`, *optional*,defaults to 8): + The scale of the `Res2NetBlock` in the encoder. + enc_se_channels (`int`, *optional*, defaults to 128): + The number of channels in the squeeze part of the `SqueezeExcitationBlock`. + """ + + def __init__( + self, + mel_dim=128, + enc_dim=1024, + enc_channels=[512, 512, 512, 512, 1536], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=128, + enc_res2net_scale=8, + enc_se_channels=128, + sample_rate=24000, + ): + self.mel_dim = mel_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + self.sample_rate = sample_rate + + +class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. + It is used to instantiate a Qwen3TTSTalkerCodePredictor model according to the specified arguments, + defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3TTSTalkerCodePredictor model. + Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. 
The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers using full attention. + The first `max_window_layers` layers will use full attention, while any + additional layer afterwards will use SWA (Sliding Window Attention). + layer_types (`list`, *optional*): + Attention pattern for each layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + """ + + model_type = "qwen3_tts_talker_code_predictor" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=2048, + hidden_size=1024, + intermediate_size=3072, + num_hidden_layers=5, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=0.000001, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000, + rope_scaling=None, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + layer_types=None, + attention_dropout=0, + num_code_groups=32, + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + 
self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + self.num_code_groups = num_code_groups + + +class Qwen3TTSTalkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate a + Qwen3TTSTalker model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3TTSTalker model. + Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
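As a small illustration of the default `layer_types` derivation above, a sketch with hypothetical values (the shipped defaults use `use_sliding_window=False`, which sets `sliding_window` to `None` and disables sliding attention entirely):

num_hidden_layers, max_window_layers = 6, 4
sliding_window = 1024  # only set when use_sliding_window=True

layer_types = [
    "sliding_attention"
    if sliding_window is not None and i >= max_window_layers
    else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention']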
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
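For reference, an illustrative `rope_scaling` dictionary matching the expected contents described above (the values are examples only, not defaults shipped with any checkpoint):

rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,                             # handle roughly 4x the pretrained context length
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
}
talker_config = Qwen3TTSTalkerConfig(rope_scaling=rope_scaling)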
+ """ + + model_type = "qwen3_tts_talker" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3TTSTalker` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig} + + def __init__( + self, + code_predictor_config=None, + vocab_size=3072, + hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=20, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=0.000001, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000, + rope_scaling=None, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + attention_dropout=0, + num_code_groups=32, + text_hidden_size=2048, + codec_eos_token_id=4198, + codec_think_id=4202, + codec_nothink_id=4203, + codec_think_bos_id=4204, + codec_think_eos_id=4205, + codec_pad_id=4196, + codec_bos_id=4197, + spk_id=None, + spk_is_dialect=None, + codec_language_id=None, + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + if code_predictor_config is None: + code_predictor_config = {} + self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig() + logger.info("code_predictor_config is None. 
Initializing code_predictor model with default values") + elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig): + self.code_predictor_config = code_predictor_config + else: + self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config) + self.num_code_groups = num_code_groups + self.text_hidden_size = text_hidden_size + self.codec_eos_token_id = codec_eos_token_id + self.codec_think_id = codec_think_id + self.codec_language_id = codec_language_id + self.codec_nothink_id = codec_nothink_id + self.codec_think_bos_id = codec_think_bos_id + self.codec_think_eos_id = codec_think_eos_id + self.codec_pad_id = codec_pad_id + self.codec_bos_id = codec_bos_id + self.spk_id = spk_id + self.spk_is_dialect = spk_is_dialect + + +class Qwen3TTSConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`]. + """ + + model_type = "qwen3_tts" + sub_configs = { + "talker_config": Qwen3TTSTalkerConfig, + "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig, + } + + def __init__( + self, + talker_config=None, + speaker_encoder_config=None, + tokenizer_type=None, + tts_model_size=None, + tts_model_type=None, + im_start_token_id=151644, + im_end_token_id=151645, + tts_pad_token_id=151671, + tts_bos_token_id=151672, + tts_eos_token_id=151673, + **kwargs, + ): + super().__init__(**kwargs) + + if talker_config is None: + talker_config = {} + logger.info("talker_config is None. Initializing talker model with default values") + if speaker_encoder_config is None: + speaker_encoder_config = {} + logger.info("speaker_encoder_config is None. Initializing talker model with default values") + + self.talker_config = Qwen3TTSTalkerConfig(**talker_config) + self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config) + + self.tokenizer_type = tokenizer_type + self.tts_model_size = tts_model_size + self.tts_model_type = tts_model_type + + self.im_start_token_id = im_start_token_id + self.im_end_token_id = im_end_token_id + self.tts_pad_token_id = tts_pad_token_id + self.tts_bos_token_id = tts_bos_token_id + self.tts_eos_token_id = tts_eos_token_id + + # TODO: remove these dummy values after + self.image_token_id = 0 # dummy image token id + self.video_token_id = 0 # dummy video token id + self.vision_start_token_id = 0 # dummy vision start token id + self.vision_config = PretrainedConfig() # dummy vision config + self.vision_config.spatial_merge_size = 1 + + def get_text_config(self, **kwargs): + # vLLM expects text config to expose hidden_size/num_attention_heads. + # For Qwen3 TTS, the talker config is the text model config. + config = self.talker_config + # if hasattr(config, "rope_parameters"): + # delattr(config, "rope_parameters") + return config + + +__all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..1e759a8d2b43db92bb22ee236ca161a86953f737 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py @@ -0,0 +1,2326 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
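A minimal usage sketch for the configuration classes defined in this file (field values are illustrative): nested dicts passed to `Qwen3TTSConfig` are promoted to the typed sub-configs, and `get_text_config()` exposes the talker config that vLLM treats as the text model config.

from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import (
    Qwen3TTSConfig,
    Qwen3TTSTalkerConfig,
)

cfg = Qwen3TTSConfig(
    talker_config={"hidden_size": 1024, "num_hidden_layers": 20},
    speaker_encoder_config={"mel_dim": 128},
)
assert isinstance(cfg.talker_config, Qwen3TTSTalkerConfig)
assert cfg.get_text_config() is cfg.talker_config
print(cfg.talker_config.hidden_size)  # 1024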
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3TTS model.""" + +import json +import os +from collections.abc import Callable +from dataclasses import dataclass + +import torch +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + ModelOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import can_return_tuple, logging +from transformers.utils.hub import cached_file + +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +from .configuration_qwen3_tts import ( + Qwen3TTSConfig, + Qwen3TTSSpeakerEncoderConfig, + Qwen3TTSTalkerCodePredictorConfig, + Qwen3TTSTalkerConfig, +) +from .qwen3_tts_tokenizer import Qwen3TTSTokenizer + +logger = logging.get_logger(__name__) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class 
AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. 
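A tiny numerical sketch (toy values, single channel) of the attention-weighted mean and standard deviation computed by `_compute_statistics` in the pooling layer above:

import torch

x = torch.tensor([[[1.0, 2.0, 3.0]]])   # (batch, channels, time)
w = torch.tensor([[[0.2, 0.3, 0.5]]])   # attention weights, sum to 1 over time

mean = (w * x).sum(dim=2)                                            # 2.3
std = torch.sqrt((w * (x - mean.unsqueeze(2)).pow(2)).sum(dim=2).clamp(1e-12))
print(mean.item(), std.item())  # ~2.3 and ~0.781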
+ """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class Qwen3TTSSpeakerEncoder(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). + Use for Qwen3TTS extract speaker embedding. + """ + + def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +def dynamic_range_compression_torch(x, c=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * c) + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int = None, + center: bool = 
False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank + (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. + If None, defaults to half the sampling rate (fmax = sr / 2.0) + inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + + mel_basis = torch.from_numpy(mel).float().to(device) + hann_window = torch.hann_window(win_size).to(device) + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = dynamic_range_compression_torch(mel_spec) + + return mel_spec + + +def _compute_default_rope_parameters( + config, + device, +): + base = config.rope_theta + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + +class Qwen3TTSPreTrainedModel(PreTrainedModel): + config_class = Qwen3TTSConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3TTSDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = False + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 + + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + if module.weight is not None: + module.weight.data.fill_(1.0) + if module.bias is not None: + 
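A usage sketch for the `mel_spectrogram` helper defined earlier in this file (all parameter values below are illustrative, not the model's actual feature-extraction settings):

import torch

wav = torch.randn(1, 24000).clamp(-1.0, 1.0)   # one second of dummy audio at 24 kHz
mel = mel_spectrogram(
    wav,
    n_fft=1024,
    num_mels=128,
    sampling_rate=24000,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=12000,
)
print(mel.shape)  # (1, 128, n_frames) log-mel spectrogram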
module.bias.data.zero_() + + +class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel): + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3TTSRMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen3TTSTalkerRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3TTSTalkerConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn: Callable = _compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3TTSThinkerText has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3TTSRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3TTSConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn: Callable = _compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3TTSRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3TTSRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
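A quick sanity sketch (illustrative shapes) of what `Qwen3TTSRMSNorm` computes: with its default unit weight it rescales each hidden vector by 1/sqrt(mean(x^2) + eps).

import torch

x = torch.randn(2, 5, 16)
norm = Qwen3TTSRMSNorm(16)   # weight initialized to ones
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), ref, atol=1e-5)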
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, mrope_interleaved=False, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
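A toy sketch of the non-interleaved `mrope_section` path above (sizes are illustrative): the channel dimension of `cos` is split into per-axis sections, and each section is taken from the matching temporal, height, or width row.

import torch

mrope_section = [2, 3, 3]     # t/h/w channels per half of the rotary dim
# cos: (3, batch, seq, head_dim), one leading row per axis after the 3D rotary embedding
cos = torch.arange(3 * 1 * 1 * 16, dtype=torch.float32).reshape(3, 1, 1, 16)

sections = mrope_section * 2  # duplicated because emb = cat((freqs, freqs), dim=-1)
picked = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(sections, dim=-1))],
    dim=-1,
)
print(picked.shape)  # (1, 1, 16): t, h, w channels stitched back into one head_dim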
+ """ + if mrope_interleaved: + + def apply_interleaved_rope(x, modality_num): + x_t = x[0].clone() + index_ranges = [] + for i, n in enumerate(mrope_section[1:], 1): + beg_idx = i + end_idx = n * modality_num + index_ranges.append((beg_idx, end_idx)) + for beg_idx, end_idx in index_ranges: + x_t[..., beg_idx:end_idx:modality_num] = x[beg_idx, ..., beg_idx:end_idx:modality_num] + return x_t + + dim = cos.shape[-1] + modality_num = len(mrope_section) + cos = torch.cat([apply_interleaved_rope(cos[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([apply_interleaved_rope(sin[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( + unsqueeze_dim + ) + else: + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3TTSTalkerAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = getattr(config, "sliding_window", None) + self.rope_scaling = config.rope_scaling + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], self.rope_scaling["interleaved"] + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3TTSTalkerResizeMLP(nn.Module): + def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False): + super().__init__() + self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias) + self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias) + self.act_fn = ACT2FN[act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +@dataclass +class Qwen3TTSTalkerCodePredictorOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head + (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, + returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + generation_steps: int | None = None + + +class Qwen3TTSTalkerTextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3TTSAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3TTSConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3TTSDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3TTSConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3TTSAttention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3TTSTalkerTextMLP(config) + self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + 
hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen3TTSTalkerCodePredictorModel(Qwen3TTSPreTrainedModel): + config_class = Qwen3TTSTalkerCodePredictorConfig + base_model_prefix = "talker.code_predictor.model" + + def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, embedding_dim: int): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.layers = nn.ModuleList( + [Qwen3TTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3TTSRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + self.codec_embedding = nn.ModuleList( + [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)] + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.codec_embedding + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + generation_steps=None, + **flash_attn_kwargs, + ) -> BaseModelOutputWithPast: + if input_ids is not None: + raise ValueError("`input_ids` is expected to be `None`") + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Qwen3TTSTalkerCodePredictorConfig + base_model_prefix = "talker.code_predictor" + + def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, talker_config: Qwen3TTSTalkerConfig): + super().__init__(config) + self.model = Qwen3TTSTalkerCodePredictorModel(config, talker_config.hidden_size) + self.vocab_size = config.vocab_size + self.lm_head = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] + ) + + if config.hidden_size != talker_config.hidden_size: + self.small_to_mtp_projection = torch.nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True) + else: + self.small_to_mtp_projection = torch.nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward_finetune( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + 
cache_position=None, + generation_steps=None, + **kwargs, + ) -> CausalLMOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logits = [] + for i in range(1, self.config.num_code_groups): + logits.append(self.lm_head[i - 1](hidden_states[:, i])) + logits = torch.stack(logits, dim=1) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return Qwen3TTSTalkerCodePredictorOutputWithPast(loss=loss, logits=logits) + + @can_return_tuple + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + generation_steps=None, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
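+ generation_steps (`int`, *optional*):
+ Index of the residual codec group currently being decoded. During prefill it is derived from the length of
+ `inputs_embeds`; during decoding it selects which codec embedding table and which `lm_head` entry are used
+ to predict the next codebook.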
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Prefill stage + if inputs_embeds is not None and inputs_embeds.shape[1] > 1: + generation_steps = inputs_embeds.shape[1] - 2 # hidden & layer 0 + # Generation stage + else: + inputs_embeds = self.model.get_input_embeddings()[generation_steps - 1](input_ids) + inputs_embeds = self.small_to_mtp_projection(inputs_embeds) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head[generation_steps](hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return Qwen3TTSTalkerCodePredictorOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + generation_steps=generation_steps + 1, + ) + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + model_kwargs["generation_steps"] = outputs.generation_steps + return model_kwargs + + +@dataclass +class Qwen3TTSTalkerOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head + (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, + returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + past_hidden: torch.FloatTensor | None = None + generation_step: int | None = None + trailing_text_hidden: torch.FloatTensor | None = None + tts_pad_embed: torch.FloatTensor | None = None + + +class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3TTSTalkerAttention(config, layer_idx) + + self.mlp = Qwen3TTSTalkerTextMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen3TTSTalkerModel(Qwen3TTSTalkerTextPreTrainedModel): + config_class = Qwen3TTSTalkerConfig + base_model_prefix = "talker.model" + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.layers = nn.ModuleList( + [Qwen3TTSTalkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3TTSTalkerRotaryEmbedding(config) + self.gradient_checkpointing = False + self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.text_embedding = nn.Embedding(config.text_vocab_size, config.text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.codec_embedding + + def get_text_embeddings(self): + return self.text_embedding + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen3TTSTalkerForConditionalGeneration(Qwen3TTSTalkerTextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Qwen3TTSTalkerConfig + base_model_prefix = "talker" + + def __init__(self, config: Qwen3TTSTalkerConfig): + super().__init__(config) + self.model = Qwen3TTSTalkerModel(config) + self.vocab_size = config.vocab_size + self.text_projection = Qwen3TTSTalkerResizeMLP( + config.text_hidden_size, config.text_hidden_size, config.hidden_size, config.hidden_act, bias=True + ) + + self.codec_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.code_predictor = Qwen3TTSTalkerCodePredictorModelForConditionalGeneration( + config=config.code_predictor_config, talker_config=config + ) + self.rope_deltas = None + + # Initialize weights and apply final processing + self.post_init() + + # TODO: hack, modular cannot inherit multiple classes + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def get_text_embeddings(self): + return 
self.model.get_text_embeddings() + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward_sub_talker_finetune(self, codec_ids, talker_hidden_states): + assert len(codec_ids.shape) == 2 + assert len(talker_hidden_states.shape) == 2 + assert codec_ids.shape[0] == talker_hidden_states.shape[0] + assert talker_hidden_states.shape[1] == self.config.hidden_size + assert codec_ids.shape[1] == self.config.num_code_groups + + sub_talker_inputs_embeds = [talker_hidden_states.unsqueeze(1)] + + for i in range(self.config.num_code_groups - 1): + if i == 0: + sub_talker_inputs_embeds.append(self.get_input_embeddings()(codec_ids[:, :1])) + else: + sub_talker_inputs_embeds.append( + self.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, i : i + 1]) + ) + sub_talker_inputs_embeds = torch.cat(sub_talker_inputs_embeds, dim=1) + + sub_talker_outputs = self.code_predictor.forward_finetune( + inputs_embeds=sub_talker_inputs_embeds, labels=codec_ids[:, 1:] + ) + + sub_talker_logits = sub_talker_outputs.logits + sub_talker_loss = sub_talker_outputs.loss + return sub_talker_logits, sub_talker_loss + + @can_return_tuple + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + past_hidden=None, + trailing_text_hidden=None, + tts_pad_embed=None, + generation_step=None, + subtalker_dosample=None, + subtalker_top_p=None, + subtalker_top_k=None, + subtalker_temperature=None, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
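+
+ Example (a minimal sketch with dummy tensors; in practice these inputs are prepared and batched by
+ `Qwen3TTSForConditionalGeneration.generate`, and the `talker` variable and the shapes below are assumptions):
+
+ ```python
+ >>> import torch
+ >>> batch, seq, hidden = 1, 8, talker.config.hidden_size
+ >>> out = talker(
+ ...     inputs_embeds=torch.zeros(batch, seq, hidden, dtype=talker.dtype, device=talker.device),
+ ...     attention_mask=torch.ones(batch, seq, dtype=torch.long, device=talker.device),
+ ...     trailing_text_hidden=torch.zeros(batch, 4, hidden, dtype=talker.dtype, device=talker.device),
+ ...     tts_pad_embed=torch.zeros(batch, 1, hidden, dtype=talker.dtype, device=talker.device),
+ ...     use_cache=True,
+ ... )
+ >>> out.logits.shape  # (batch, seq, talker.config.vocab_size)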
+ ```""" + # Prefill + if inputs_embeds is not None and inputs_embeds.shape[1] > 1: + generation_step = -1 + codec_ids = None + # Generate + else: + last_id_hidden = self.get_input_embeddings()(input_ids) + predictor_result = self.code_predictor.generate( + inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1), + max_new_tokens=self.config.num_code_groups - 1, + do_sample=subtalker_dosample, + top_p=subtalker_top_p, + top_k=subtalker_top_k, + temperature=subtalker_temperature, + output_hidden_states=True, + return_dict_in_generate=True, + ) + codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1) + codec_hiddens = torch.cat( + [last_id_hidden] + + [ + self.code_predictor.get_input_embeddings()[i](predictor_result.sequences[..., i : i + 1]) + for i in range(self.config.num_code_groups - 1) + ], + dim=1, + ) + inputs_embeds = codec_hiddens.sum(1, keepdim=True) + + if generation_step < trailing_text_hidden.shape[1]: + inputs_embeds = inputs_embeds + trailing_text_hidden[:, generation_step].unsqueeze(1) + else: + inputs_embeds = inputs_embeds + tts_pad_embed + if attention_mask is not None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) + position_ids, rope_deltas = self.get_rope_index( + attention_mask, + ) + rope_deltas = rope_deltas - delta0 + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs: BaseModelOutputWithPast = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.codec_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return Qwen3TTSTalkerOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=(outputs.hidden_states, codec_ids), + attentions=outputs.attentions, + past_hidden=hidden_states[:, -1:, :], + generation_step=generation_step + 1, + trailing_text_hidden=trailing_text_hidden, + tts_pad_embed=tts_pad_embed, + ) + + def get_rope_index( + self, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. 
+ Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. + This means one frame is processed each second. + interval: The step size for the temporal position IDs, + calculated as tokens_per_second * temporal_patch_size / fps. + In this case, 25 * 2 / 1 = 50. This means that each temporal + patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + mrope_position_deltas = [] + + position_ids = attention_mask.float().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + model_kwargs["past_hidden"] = outputs.past_hidden + model_kwargs["generation_step"] = outputs.generation_step + model_kwargs["trailing_text_hidden"] = outputs.trailing_text_hidden + model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed + return model_kwargs + + +class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): + config_class = Qwen3TTSConfig + + def __init__(self, config: Qwen3TTSConfig): + super().__init__(config) + self.config = config + + self.talker = Qwen3TTSTalkerForConditionalGeneration(self.config.talker_config) + + if config.tts_model_type == "base": + self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) + else: + self.speaker_encoder = None + + self.speech_tokenizer = None + self.generate_config = None + + self.supported_speakers = self.config.talker_config.spk_id.keys() + self.supported_languages = ["auto"] + for language_id in self.config.talker_config.codec_language_id.keys(): + if "dialect" not in language_id: + self.supported_languages.append(language_id) + + self.speaker_encoder_sample_rate = self.config.speaker_encoder_config.sample_rate + self.tokenizer_type = self.config.tokenizer_type + self.tts_model_size = self.config.tts_model_size + self.tts_model_type = self.config.tts_model_type + + self.post_init() + + def load_speech_tokenizer(self, speech_tokenizer): + self.speech_tokenizer = speech_tokenizer + + def load_generate_config(self, generate_config): + self.generate_config = 
generate_config + + def get_supported_speakers(self): + return self.supported_speakers + + def get_supported_languages(self): + return self.supported_languages + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + config=None, + cache_dir=None, + ignore_mismatched_sizes=False, + force_download=False, + local_files_only=False, + token=None, + revision="main", + use_safetensors=None, + weights_only=True, + **kwargs, + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + if not local_files_only and not os.path.isdir(pretrained_model_name_or_path): + download_cache_dir = kwargs.get("cache_dir", cache_dir) + download_revision = kwargs.get("revision", revision) + download_weights_from_hf_specific( + pretrained_model_name_or_path, + cache_dir=download_cache_dir, + allow_patterns=["speech_tokenizer/*"], + revision=download_revision, + ) + speech_tokenizer_path = cached_file( + pretrained_model_name_or_path, + "speech_tokenizer/config.json", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + if speech_tokenizer_path is None: + raise ValueError(f"""{pretrained_model_name_or_path}/{speech_tokenizer_path} not exists""") + speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) + speech_tokenizer = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + *model_args, + **kwargs, + ) + model.load_speech_tokenizer(speech_tokenizer) + + generate_config_path = cached_file( + pretrained_model_name_or_path, + "generation_config.json", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + with open(generate_config_path, encoding="utf-8") as f: + generate_config = json.load(f) + model.load_generate_config(generate_config) + + return model + + @torch.inference_mode() + def extract_speaker_embedding(self, audio, sr): + assert sr == 24000, "Only support 24kHz audio" + mels = mel_spectrogram( + torch.from_numpy(audio).unsqueeze(0), + n_fft=1024, + num_mels=128, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=12000, + ).transpose(1, 2) + speaker_embedding = self.speaker_encoder(mels.to(self.device).to(self.dtype))[0] + return speaker_embedding + + @torch.inference_mode() + def generate_speaker_prompt(self, voice_clone_prompt: list[dict]): + voice_clone_spk_embeds = [] + for index in range(len(voice_clone_prompt["ref_spk_embedding"])): + ref_spk_embedding = ( + voice_clone_prompt["ref_spk_embedding"][index].to(self.talker.device).to(self.talker.dtype) + ) + voice_clone_spk_embeds.append(ref_spk_embedding) + + return voice_clone_spk_embeds + + def generate_icl_prompt( + self, + text_id: 
torch.Tensor, + ref_id: torch.Tensor, + ref_code: torch.Tensor, + tts_pad_embed: torch.Tensor, + tts_eos_embed: torch.Tensor, + non_streaming_mode: bool, + ): + # text embed (ref id + text id + eos) 1 T1 D + text_embed = self.talker.text_projection( + self.talker.get_text_embeddings()(torch.cat([ref_id, text_id], dim=-1)) + ) + text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) + # codec embed (codec bos + codec) 1 T2 D + codec_embed = [] + for i in range(self.talker.config.num_code_groups): + if i == 0: + codec_embed.append(self.talker.get_input_embeddings()(ref_code[:, :1])) + else: + codec_embed.append(self.talker.code_predictor.get_input_embeddings()[i - 1](ref_code[:, i : i + 1])) + codec_embed = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0) + codec_embed = torch.cat( + [ + self.talker.get_input_embeddings()( + torch.tensor( + [ + [ + self.config.talker_config.codec_bos_id, + ] + ], + device=self.talker.device, + dtype=text_id.dtype, + ) + ), + codec_embed, + ], + dim=1, + ) + # compute lens + text_lens = text_embed.shape[1] + codec_lens = codec_embed.shape[1] + if non_streaming_mode: + icl_input_embed = text_embed + self.talker.get_input_embeddings()( + torch.tensor( + [ + [ + self.config.talker_config.codec_pad_id, + ] + * text_lens + ], + device=self.talker.device, + dtype=text_id.dtype, + ) + ) + icl_input_embed = torch.cat([icl_input_embed, codec_embed + tts_pad_embed], dim=1) + return icl_input_embed, tts_pad_embed + else: + if text_lens > codec_lens: + return text_embed[:, :codec_lens] + codec_embed, text_embed[:, codec_lens:] + else: + text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1) + return text_embed + codec_embed, tts_pad_embed + + @torch.no_grad() + def generate( + self, + input_ids: list[torch.Tensor] | None = None, + instruct_ids: list[torch.Tensor] | None = None, + ref_ids: list[torch.Tensor] | None = None, + voice_clone_prompt: list[dict] = None, + languages: list[str] = None, + speakers: list[str] = None, + non_streaming_mode=False, + max_new_tokens: int = 4096, + do_sample: bool = True, + top_k: int = 50, + top_p: float = 1.0, + temperature: float = 0.9, + subtalker_dosample: bool = True, + subtalker_top_k: int = 50, + subtalker_top_p: float = 1.0, + subtalker_temperature: float = 0.9, + eos_token_id: int | None = None, + repetition_penalty: float = 1.05, + **kwargs, + ): + talker_kwargs = { + "max_new_tokens": max_new_tokens, + "min_new_tokens": 2, + "do_sample": do_sample, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "subtalker_dosample": subtalker_dosample, + "subtalker_top_k": subtalker_top_k, + "subtalker_top_p": subtalker_top_p, + "subtalker_temperature": subtalker_temperature, + "eos_token_id": eos_token_id if eos_token_id is not None else self.config.talker_config.codec_eos_token_id, + "repetition_penalty": repetition_penalty, + "suppress_tokens": [ + i + for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size) + if i not in (self.config.talker_config.codec_eos_token_id,) + ], + "output_hidden_states": getattr(kwargs, "output_hidden_states", True), + "return_dict_in_generate": getattr(kwargs, "return_dict_in_generate", True), + } + + talker_input_embeds = [[] for _ in range(len(input_ids))] + + voice_clone_spk_embeds = None + # voice clone speaker prompt generate + if voice_clone_prompt is not None: + voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt) + + # instruct text prompt generate + if instruct_ids is not None: + 
for index, instruct_id in enumerate(instruct_ids): + if instruct_id is not None: + talker_input_embeds[index].append( + self.talker.text_projection(self.talker.get_text_embeddings()(instruct_id)) + ) + + # tts text prompt generate + trailing_text_hiddens = [] + if speakers is None: + speakers = [None] * len(input_ids) + for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)): + if voice_clone_spk_embeds is None: + if speaker == "" or speaker is None: # Instruct create speaker + speaker_embed = None + else: + if speaker.lower() not in self.config.talker_config.spk_id: + raise NotImplementedError(f"Speaker {speaker} not implemented") + else: + spk_id = self.config.talker_config.spk_id[speaker.lower()] + speaker_embed = self.talker.get_input_embeddings()( + torch.tensor( + spk_id, + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + else: + if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]: + speaker_embed = voice_clone_spk_embeds[index] + else: + speaker_embed = None + + assert language is not None + + if language.lower() == "auto": + language_id = None + else: + if language.lower() not in self.config.talker_config.codec_language_id: + raise NotImplementedError(f"Language {language} not implemented") + else: + language_id = self.config.talker_config.codec_language_id[language.lower()] + + if ( + language.lower() in ["chinese", "auto"] + and speaker != "" + and speaker is not None + and self.config.talker_config.spk_is_dialect[speaker.lower()] is not False + ): + dialect = self.config.talker_config.spk_is_dialect[speaker.lower()] + language_id = self.config.talker_config.codec_language_id[dialect] + + tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection( + self.talker.get_text_embeddings()( + torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + ).chunk(3, dim=1) # 3 * [1 1 d] + + # codec: tag and speaker + if language_id is None: + codec_prefill_list = [ + [ + self.config.talker_config.codec_nothink_id, + self.config.talker_config.codec_think_bos_id, + self.config.talker_config.codec_think_eos_id, + ] + ] + else: + codec_prefill_list = [ + [ + self.config.talker_config.codec_think_id, + self.config.talker_config.codec_think_bos_id, + language_id, + self.config.talker_config.codec_think_eos_id, + ] + ] + + codec_input_emebdding_0 = self.talker.get_input_embeddings()( + torch.tensor( + codec_prefill_list, + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + codec_input_emebdding_1 = self.talker.get_input_embeddings()( + torch.tensor( + [ + [ + self.config.talker_config.codec_pad_id, + self.config.talker_config.codec_bos_id, + ] + ], + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + if speaker_embed is None: + codec_input_emebdding = torch.cat([codec_input_emebdding_0, codec_input_emebdding_1], dim=1) + else: + codec_input_emebdding = torch.cat( + [codec_input_emebdding_0, speaker_embed.view(1, 1, -1), codec_input_emebdding_1], dim=1 + ) + + # '<|im_start|>assistant\n我叫通义千问,是阿里云的开源大模型。<|im_end|>\n<|im_start|>assistant\n' + + # <|im_start|>assistant\n + _talker_input_embed_role = self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, :3])) + + # tts_pad * 4 + tts_bos + _talker_input_embed = ( + torch.cat( + ( + tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1), + tts_bos_embed, + ), + dim=1, + ) + + 
codec_input_emebdding[:, :-1]
+ )
+
+ talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1)
+
+ if (
+ voice_clone_prompt is not None
+ and voice_clone_prompt["ref_code"] is not None
+ and voice_clone_prompt["icl_mode"][index]
+ ):
+ icl_input_embed, trailing_text_hidden = self.generate_icl_prompt(
+ text_id=input_id[:, 3:-5],
+ ref_id=ref_ids[index][:, 3:-2],
+ ref_code=voice_clone_prompt["ref_code"][index].to(self.talker.device),
+ tts_pad_embed=tts_pad_embed,
+ tts_eos_embed=tts_eos_embed,
+ non_streaming_mode=non_streaming_mode,
+ )
+ talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1)
+ else:
+ # tts_text_first_token
+ talker_input_embed = torch.cat(
+ [
+ talker_input_embed,
+ self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4]))
+ + codec_input_emebdding[:, -1:],
+ ],
+ dim=1,
+ )
+ if non_streaming_mode:
+ talker_input_embed = talker_input_embed[:, :-1] # drop the text token that was appended above
+ talker_input_embed = torch.cat(
+ [
+ talker_input_embed,
+ torch.cat(
+ (
+ self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:-5])),
+ tts_eos_embed,
+ ),
+ dim=1,
+ )
+ + self.talker.get_input_embeddings()(
+ torch.tensor(
+ [
+ [
+ self.config.talker_config.codec_pad_id,
+ ]
+ * (input_id[:, 3:-5].shape[1] + 1)
+ ],
+ device=self.talker.device,
+ dtype=input_id.dtype,
+ )
+ ),
+ tts_pad_embed
+ + self.talker.get_input_embeddings()(
+ torch.tensor(
+ [
+ [
+ self.config.talker_config.codec_bos_id,
+ ]
+ ],
+ device=self.talker.device,
+ dtype=input_id.dtype,
+ )
+ ),
+ ],
+ dim=1,
+ )
+ trailing_text_hidden = tts_pad_embed
+ else:
+ # remaining text after the first token (e.g. "叫通义千问,是阿里云的开源大模型。"), followed by tts_eos
+ trailing_text_hidden = torch.cat(
+ (
+ self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 4:-5])),
+ tts_eos_embed,
+ ),
+ dim=1,
+ )
+ talker_input_embeds[index].append(talker_input_embed)
+ trailing_text_hiddens.append(trailing_text_hidden)
+
+ for index, talker_input_embed in enumerate(talker_input_embeds):
+ talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not None], dim=1)
+
+ # for batch inference
+ original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds])
+ # left padding for talker input embeds
+ sequences = [t.squeeze(0) for t in talker_input_embeds]
+ sequences_reversed = [t.flip(dims=[0]) for t in sequences]
+ padded_reversed = torch.nn.utils.rnn.pad_sequence(sequences_reversed, batch_first=True, padding_value=0.0)
+ talker_input_embeds = padded_reversed.flip(dims=[1])
+ # generate mask
+ batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1]
+ indices = torch.arange(max_len).expand(batch_size, -1)
+ num_pads = max_len - original_lengths
+ talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device)
+ # padding trailing text hiddens
+ pad_embedding_vector = tts_pad_embed.squeeze()
+ sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens]
+ trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad]
+ padded_hiddens = torch.nn.utils.rnn.pad_sequence(sequences_to_pad, batch_first=True, padding_value=0.0)
+ arange_tensor = torch.arange(max(trailing_text_original_lengths), device=padded_hiddens.device).expand(
+ len(trailing_text_original_lengths), -1
+ )
+ lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1)
+ padding_mask = arange_tensor >= lengths_tensor
+ padded_hiddens[padding_mask] = pad_embedding_vector
+ trailing_text_hiddens = 
padded_hiddens + + # forward + talker_result = self.talker.generate( + inputs_embeds=talker_input_embeds, + attention_mask=talker_attention_mask, + trailing_text_hidden=trailing_text_hiddens, + tts_pad_embed=tts_pad_embed, + **talker_kwargs, + ) + + talker_codes = torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1) + talker_hidden_states = torch.cat([hid[0][-1][:, -1:] for hid in talker_result.hidden_states], dim=1)[:, :-1] + + first_codebook = talker_codes[:, :, 0] + is_stop_token = first_codebook == self.config.talker_config.codec_eos_token_id + stop_indices = torch.argmax(is_stop_token.int(), dim=1) + has_stop_token = is_stop_token.any(dim=1) + effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1]) + + talker_codes_list = [ + talker_codes[ + i, + :length, + ] + for i, length in enumerate(effective_lengths) + ] + talker_hidden_states_list = [talker_hidden_states[i, :length, :] for i, length in enumerate(effective_lengths)] + + return talker_codes_list, talker_hidden_states_list + + +__all__ = [ + "Qwen3TTSForConditionalGeneration", + "Qwen3TTSTalkerForConditionalGeneration", + "Qwen3TTSPreTrainedModel", + "Qwen3TTSTalkerModel", +] diff --git a/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..5643a857cdbdb34f24045653d02a32603acf7d1b --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py @@ -0,0 +1,102 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin + + +class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "padding_side": "left", + } + } + + +class Qwen3TTSProcessor(ProcessorMixin): + r""" + Constructs a Qwen3TTS processor. + + Args: + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The text tokenizer. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. + If not provided, the default chat template is used. + """ + + attributes = ["tokenizer"] + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, tokenizer=None, chat_template=None): + super().__init__(tokenizer, chat_template=chat_template) + + def __call__(self, text=None, **kwargs) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and audio(s). + This method forwards the `text` and `kwargs` arguments to + Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` + to encode the text. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. 
Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + """ + + if text is None: + raise ValueError("You need to specify either a `text` input to process.") + + output_kwargs = self._merge_kwargs( + Qwen3TTSProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if not isinstance(text, list): + text = [text] + + texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature( + data={**texts_inputs}, + tensor_type=kwargs.get("return_tensors"), + ) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def apply_chat_template(self, conversations, chat_template=None, **kwargs): + if isinstance(conversations[0], dict): + conversations = [conversations] + return super().apply_chat_template(conversations, chat_template, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + return list(dict.fromkeys(tokenizer_input_names)) + + +__all__ = ["Qwen3TTSProcessor"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..8514a725d4bdc0788022bac678b75c5ba7cbc16d --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py @@ -0,0 +1,1088 @@ +# Copyright 2026 The Alibaba Qwen team. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
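+#
+# This module wraps the Hugging Face-style Qwen3 TTS implementation for vLLM-Omni:
+# `Qwen3TTSModelForGeneration` adapts the model to the vLLM runner interface, while `Qwen3TTSModel` offers a
+# `from_pretrained`-style entry point and the CustomVoice / VoiceDesign / Base (voice clone) generation APIs.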
+import base64 +import io +import urllib.request +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any +from urllib.parse import urlparse + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel, AutoProcessor +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .configuration_qwen3_tts import Qwen3TTSConfig +from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration +from .processing_qwen3_tts import Qwen3TTSProcessor + +logger = init_logger(__name__) + +AudioLike = ( + str # wav path, URL, base64 + | np.ndarray # waveform (requires sr) + | tuple[np.ndarray, int] # (waveform, sr) +) + +MaybeList = Any | list[Any] + + +@dataclass +class VoiceClonePromptItem: + """ + Container for one sample's voice-clone prompt information that can be fed to the model. + + Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`. + """ + + ref_code: torch.Tensor | None # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz + ref_spk_embedding: torch.Tensor # (D,) + x_vector_only_mode: bool + icl_mode: bool + ref_text: str | None = None + + +class Qwen3TTSModelForGeneration(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + model_path = vllm_config.model_config.model + + # Check if flash-attn is installed + try: + import flash_attn # noqa: F401 + + attn_kwargs = {"attn_implementation": "flash_attention_2"} + except ImportError: + logger.warning("Flash-Attn is not installed. Using default PyTorch attention implementation.") + attn_kwargs = {} + + self.model = Qwen3TTSModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + **attn_kwargs, + ) + self.task_type = model_path.split("-")[-1].strip("/") + # Mark that this model produces multimodal outputs + self.have_multimodal_outputs = True + + # Store vllm_config for potential future use + self.vllm_config = vllm_config + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> OmniOutput: + """ + Forward pass for TTS generation model. + + Args: + input_ids: Input token IDs (required for TTS generation) + positions: Position IDs (not used for TTS, but required by runner) + intermediate_tensors: Intermediate tensors for pipeline parallelism (not used) + inputs_embeds: Input embeddings (not used for TTS, but required by runner) + **kwargs: Additional arguments including task_type, sampling_metadata, etc. 
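+ The per-request `runtime_additional_information` entry is expected to carry fields such as `text`,
+ `task_type`, `speaker`, `language` and `instruct`, which select the generation path used below.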
+ + Returns: + OmniOutput: Contains multimodal outputs with audio tensors + """ + + # Extract additional parameters from kwargs that the generation methods expect + + runtime_additional_information = kwargs.get("runtime_additional_information", [{}]) + if isinstance(runtime_additional_information, list) and len(runtime_additional_information) > 0: + runtime_additional_information = runtime_additional_information[0] + text = runtime_additional_information.pop("text", [""])[0] + # Extract task_type from kwargs, default to "instruct" + task_type = runtime_additional_information.pop("task_type", [self.task_type])[0] + speaker = runtime_additional_information.pop("speaker", ["uncle_fu"])[0] + language = runtime_additional_information.pop("language", ["Auto"])[0] + instruct = runtime_additional_information.pop("instruct", [""])[0] + for key, value in runtime_additional_information.items(): + if isinstance(value, list) and len(value) > 0: + runtime_additional_information[key] = value[0] + + # During profile/warmup runs, text is empty and no real inputs exist. + # Cap generation steps so the full pipeline executes (preserving + # KV-cache profiling behaviour) but exits quickly even if the model + # cannot converge from degenerate dummy inputs. + if not text: + logger.info("Profile run detected (empty text). Capping max_new_tokens to 2.") + runtime_additional_information["max_new_tokens"] = 2 + + # Call the appropriate generation method based on task_type + if task_type == "CustomVoice": + result = self.model.generate_custom_voice( + text, speaker=speaker, language=language, instruct=instruct, **runtime_additional_information + ) + elif task_type == "VoiceDesign": + result = self.model.generate_voice_design( + text, instruct=instruct, language=language, **runtime_additional_information + ) + elif task_type == "Base": + result = self.model.generate_voice_clone(text, language=language, **runtime_additional_information) + else: + raise ValueError(f"Invalid task type: {task_type}") + + # Convert result to OmniOutput format + return self.make_omni_output(result, **kwargs) + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput | tuple, **kwargs: Any) -> OmniOutput: + """ + Make an OmniOutput object from model outputs. 
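+
+ The waveform is stored under the `"model_outputs"` key of `OmniOutput.multimodal_outputs` and, when a
+ `(audio_tensors, sample_rate)` tuple is given, the sample rate is stored under `"sr"`.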
+ Args: + model_outputs: Model outputs (either OmniOutput, tuple of (audio_tensors, sr), or tensor) + """ + if isinstance(model_outputs, OmniOutput): + return model_outputs + + # Handle tuple format: (audio_tensors, sample_rate) + if isinstance(model_outputs, tuple) and len(model_outputs) == 2: + audio_tensors, sr = model_outputs + # audio_tensors is a list of numpy arrays, convert first one to tensor if needed + if isinstance(audio_tensors, list) and len(audio_tensors) > 0: + # Convert numpy array to tensor if needed + audio_tensor = audio_tensors[0] + if isinstance(audio_tensor, np.ndarray): + audio_tensor = torch.from_numpy(audio_tensor).float() + elif not isinstance(audio_tensor, torch.Tensor): + audio_tensor = torch.tensor(audio_tensor, dtype=torch.float32) + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": audio_tensor, "sr": torch.tensor(sr, dtype=torch.int)}, + ) + + # If it's already a tensor, wrap it + if isinstance(model_outputs, torch.Tensor): + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": model_outputs}, + ) + + raise ValueError(f"Unsupported model_outputs type: {type(model_outputs)}") + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + """ + Create empty intermediate tensors for pipeline parallelism. + + For TTS generation models, pipeline parallelism is typically not used, + so this returns an empty dict. However, this method is required by the + runner infrastructure. + + Args: + batch_size: Batch size for the intermediate tensors + dtype: Data type for the tensors + device: Device for the tensors + + Returns: + IntermediateTensors: Empty dict (no PP support for TTS models) + """ + # TTS generation models typically don't use pipeline parallelism + # Return empty dict to satisfy the interface + return IntermediateTensors({}) + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Any = None, + is_multimodal: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Embed input token IDs into embeddings. + + This method is called by the runner when inputs_embeds are needed. + For TTS models, we typically work with input_ids directly, but this + method provides a fallback for cases where embeddings are required. + + Args: + input_ids: Input token IDs + multimodal_embeddings: Optional multimodal embeddings (not used for TTS) + is_multimodal: Optional mask indicating multimodal tokens (not used for TTS) + **kwargs: Additional arguments + + Returns: + torch.Tensor: Embedded representations of input_ids + """ + # For TTS models, we don't have a separate embedding layer exposed, + # so we return a dummy tensor. In practice, TTS models work with + # input_ids directly in the forward pass. + # This is a minimal implementation to bypass the function call. + return torch.zeros( + (input_ids.shape[0], input_ids.shape[1], 1024), # Dummy hidden size + dtype=torch.bfloat16, + device=input_ids.device, + ) + + def embed_multimodal(self, **kwargs: Any) -> Any: + """ + Embed multimodal inputs (e.g., images, audio). + + For TTS models, this is typically not used as they work with text input_ids. + This method provides a stub to satisfy the interface. 
+ + Args: + **kwargs: Multimodal input arguments + + Returns: + None or empty list: TTS models don't use multimodal embeddings + """ + # TTS models work with text input_ids, not multimodal embeddings + # Return None to indicate no multimodal embeddings + return None + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + """Load weights into the wrapped HF model.""" + # params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + loaded_params.add(name) + + return loaded_params + + def compute_logits( + self, + hidden_states: torch.Tensor | OmniOutput, + sampling_metadata: Any = None, + ) -> torch.Tensor | None: + """Non-autoregressive TTS models do not compute token logits.""" + return None + + +class Qwen3TTSModel: + """ + A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides: + - from_pretrained() initialization via AutoModel/AutoProcessor + - generation APIs for: + * CustomVoice: generate_custom_voice() + * VoiceDesign: generate_voice_design() + * Base: generate_voice_clone() + create_voice_clone_prompt() + - consistent output: (wavs: List[np.ndarray], sample_rate: int) + + Notes: + - This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration` + - Language / speaker validation is done via model methods: + model.get_supported_languages(), model.get_supported_speakers() + """ + + def __init__( + self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: dict[str, Any] | None = None + ): + self.model = model + self.processor = processor + self.generate_defaults = generate_defaults or {} + + self.device = getattr(model, "device", None) + if self.device is None: + try: + self.device = next(model.parameters()).device + except StopIteration: + self.device = torch.device("cpu") + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + **kwargs: Any, + ) -> "Qwen3TTSModel": + """ + Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style. + + This method: + 1) Loads config via AutoConfig (so your side can register model_type -> config/model). + 2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged. + 3) Loads the processor via AutoProcessor.from_pretrained(model_path). + 4) Loads optional `generate_config.json` from the model directory/repo snapshot if present. + + Args: + pretrained_model_name_or_path (str): + HuggingFace repo id or local directory of the model. + **kwargs: + Forwarded as-is into `AutoModel.from_pretrained(...)`. + Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2". + + Returns: + Qwen3TTSModel: + Wrapper instance containing `model`, `processor`, and generation defaults. + """ + AutoConfig.register("qwen3_tts", Qwen3TTSConfig) + AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration) + AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor) + + model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) + if not isinstance(model, Qwen3TTSForConditionalGeneration): + raise TypeError(f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. 
") + + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, + fix_mistral_regex=True, + ) + + generate_defaults = model.generate_config + return cls(model=model, processor=processor, generate_defaults=generate_defaults) + + def _supported_languages_set(self) -> set | None: + langs = getattr(self.model, "get_supported_languages", None) + if callable(langs): + v = langs() + if v is None: + return None + return set([str(x).lower() for x in v]) + return None + + def _supported_speakers_set(self) -> set | None: + spks = getattr(self.model, "get_supported_speakers", None) + if callable(spks): + v = spks() + if v is None: + return None + return set([str(x).lower() for x in v]) + return None + + def _validate_languages(self, languages: list[str]) -> None: + """ + Validate that requested languages are supported by the model. + + Args: + languages (List[str]): Language names for each sample. + + Raises: + ValueError: If any language is not supported. + """ + supported = self._supported_languages_set() + if supported is None: + return + + bad = [] + for lang in languages: + if lang is None: + bad.append(lang) + continue + if str(lang).lower() not in supported: + bad.append(lang) + if bad: + raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}") + + def _validate_speakers(self, speakers: list[str | None]) -> None: + """ + Validate that requested speakers are supported by the Instruct model. + + Args: + speakers (List[Optional[str]]): Speaker names for each sample. + + Raises: + ValueError: If any speaker is not supported. + """ + supported = self._supported_speakers_set() + if supported is None: + return + + bad = [] + for spk in speakers: + if spk is None or spk == "": + continue + if str(spk).lower() not in supported: + bad.append(spk) + if bad: + raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}") + + def _is_probably_base64(self, s: str) -> bool: + if s.startswith("data:audio"): + return True + if ("/" not in s and "\\" not in s) and len(s) > 256: + return True + return False + + def _is_url(self, s: str) -> bool: + try: + u = urlparse(s) + return u.scheme in ("http", "https") and bool(u.netloc) + except Exception: + return False + + def _decode_base64_to_wav_bytes(self, b64: str) -> bytes: + if "," in b64 and b64.strip().startswith("data:"): + b64 = b64.split(",", 1)[1] + return base64.b64decode(b64) + + def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]: + if self._is_url(x): + with urllib.request.urlopen(x) as resp: + audio_bytes = resp.read() + with io.BytesIO(audio_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + elif self._is_probably_base64(x): + wav_bytes = self._decode_base64_to_wav_bytes(x) + with io.BytesIO(wav_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + else: + audio, sr = librosa.load(x, sr=None, mono=True) + + if audio.ndim > 1: + audio = np.mean(audio, axis=-1) + + return audio.astype(np.float32), int(sr) + + def _normalize_audio_inputs(self, audios: AudioLike | list[AudioLike]) -> list[tuple[np.ndarray, int]]: + """ + Normalize audio inputs into a list of (waveform, sr). + + Supported forms: + - str: wav path / URL / base64 audio string + - (np.ndarray, sr): waveform + sampling rate + - list of the above + + Args: + audios: + Audio input(s). + + Returns: + List[Tuple[np.ndarray, int]]: + List of (float32 waveform, original sr). + + Raises: + ValueError: If a numpy waveform is provided without sr. 
+        """
+        if isinstance(audios, list):
+            items = audios
+        else:
+            items = [audios]
+
+        out: list[tuple[np.ndarray, int]] = []
+        for a in items:
+            if isinstance(a, str):
+                out.append(self._load_audio_to_np(a))
+            elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
+                out.append((a[0].astype(np.float32), int(a[1])))
+            elif isinstance(a, np.ndarray):
+                raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
+            else:
+                raise TypeError(f"Unsupported audio input type: {type(a)}")
+        # Downmix any multi-channel waveform to mono. Tuples are immutable, so
+        # rebuild the (waveform, sr) pair instead of assigning into it.
+        for i, (wav, sr) in enumerate(out):
+            if wav.ndim > 1:
+                wav = np.mean(wav, axis=-1).astype(np.float32)
+            out[i] = (wav, sr)
+        return out
+
+    def _ensure_list(self, x: MaybeList) -> list[Any]:
+        return x if isinstance(x, list) else [x]
+
+    def _build_assistant_text(self, text: str) -> str:
+        return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
+
+    def _build_ref_text(self, text: str) -> str:
+        return f"<|im_start|>assistant\n{text}<|im_end|>\n"
+
+    def _build_instruct_text(self, instruct: str) -> str:
+        return f"<|im_start|>user\n{instruct}<|im_end|>\n"
+
+    def _tokenize_texts(self, texts: list[str]) -> list[torch.Tensor]:
+        input_ids = []
+        for text in texts:
+            encoded = self.processor(text=text, return_tensors="pt", padding=True)
+            input_id = encoded["input_ids"].to(self.device)
+            input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
+            input_ids.append(input_id)
+        return input_ids
+
+    def _merge_generate_kwargs(
+        self,
+        non_streaming_mode: bool | None = None,
+        do_sample: bool | None = None,
+        top_k: int | None = None,
+        top_p: float | None = None,
+        temperature: float | None = None,
+        repetition_penalty: float | None = None,
+        subtalker_dosample: bool | None = None,
+        subtalker_top_k: int | None = None,
+        subtalker_top_p: float | None = None,
+        subtalker_temperature: float | None = None,
+        max_new_tokens: int | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Merge user-provided generation arguments with defaults from `generate_config.json`.
+
+        Rule:
+        - If the user explicitly passes a value (not None), use it.
+        - Otherwise, use the value from generate_config.json if present.
+        - Otherwise, fall back to the hard defaults.
+
+        Args:
+            non_streaming_mode, do_sample, top_k, top_p, temperature, repetition_penalty,
+            subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
+                Common generation parameters.
+            **kwargs:
+                Other arguments forwarded to model.generate().
+
+        Returns:
+            Dict[str, Any]: Final kwargs to pass into model.generate().
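+
+        Example (illustrative):
+            # An explicitly passed temperature wins; anything left as None falls back to
+            # generate_config.json and then to the hard-coded defaults below.
+            merged = self._merge_generate_kwargs(temperature=0.7)
+            # merged["temperature"] == 0.7; merged["max_new_tokens"] comes from the defaults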
+ """ + hard_defaults = dict( + non_streaming_mode=False, + do_sample=True, + top_k=50, + top_p=1.0, + temperature=0.9, + repetition_penalty=1.05, + subtalker_dosample=True, + subtalker_top_k=50, + subtalker_top_p=1.0, + subtalker_temperature=0.9, + max_new_tokens=2048, + ) + + def pick(name: str, user_val: Any) -> Any: + if user_val is not None: + return user_val + if name in self.generate_defaults: + return self.generate_defaults[name] + return hard_defaults[name] + + merged = dict(kwargs) + merged.update( + non_streaming_mode=pick("non_streaming_mode", non_streaming_mode), + do_sample=pick("do_sample", do_sample), + top_k=pick("top_k", top_k), + top_p=pick("top_p", top_p), + temperature=pick("temperature", temperature), + repetition_penalty=pick("repetition_penalty", repetition_penalty), + subtalker_dosample=pick("subtalker_dosample", subtalker_dosample), + subtalker_top_k=pick("subtalker_top_k", subtalker_top_k), + subtalker_top_p=pick("subtalker_top_p", subtalker_top_p), + subtalker_temperature=pick("subtalker_temperature", subtalker_temperature), + max_new_tokens=pick("max_new_tokens", max_new_tokens), + ) + return merged + + # voice clone model + @torch.inference_mode() + def create_voice_clone_prompt( + self, + ref_audio: AudioLike | list[AudioLike], + ref_text: str | list[str | None] | None = None, + x_vector_only_mode: bool | list[bool] = False, + ) -> list[VoiceClonePromptItem]: + """ + Build voice-clone prompt items from reference audio (and optionally reference text) using Base model. + + Modes: + - x_vector_only_mode=True: + Only speaker embedding is used to clone voice; ref_text/ref_code are ignored. + This is mutually exclusive with ICL. + - x_vector_only_mode=False: + ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required, + because the model continues/conditions on the reference text + reference speech codes. + + Batch behavior: + - ref_audio can be a single item or a list. + - ref_text and x_vector_only_mode can be scalars or lists. + - If any of them are lists with length > 1, lengths must match. + + Audio input: + - str: local wav path / URL / base64 + - (np.ndarray, sr): waveform + sampling rate + + Args: + ref_audio: + Reference audio(s) used to extract: + - ref_code via `model.speech_tokenizer.encode(...)` + - ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to 24k) + ref_text: + Reference transcript(s). Required when x_vector_only_mode=False (ICL mode). + x_vector_only_mode: + Whether to use speaker embedding only. If False, ICL mode will be used. + + Returns: + List[VoiceClonePromptItem]: + List of prompt items that can be converted into `voice_clone_prompt` dict. + + Raises: + ValueError: + - If x_vector_only_mode=False but ref_text is missing. + - If batch lengths mismatch. + """ + if self.model.tts_model_type != "base": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details." 
+ ) + + ref_audio_list = self._ensure_list(ref_audio) + ref_text_list = ( + self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list)) + ) + xvec_list = ( + self._ensure_list(x_vector_only_mode) + if isinstance(x_vector_only_mode, list) + else ([x_vector_only_mode] * len(ref_audio_list)) + ) + + if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list): + raise ValueError( + f"Batch size mismatch: ref_audio={len(ref_audio_list)}, " + f"ref_text={len(ref_text_list)}, " + f"x_vector_only_mode={len(xvec_list)}" + ) + + normalized = self._normalize_audio_inputs(ref_audio_list) + + ref_wavs_for_code: list[np.ndarray] = [] + ref_sr_for_code: list[int] = [] + for wav, sr in normalized: + ref_wavs_for_code.append(wav) + ref_sr_for_code.append(sr) + + if len(set(ref_sr_for_code)) == 1: + enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0]) + ref_codes = enc.audio_codes + else: + ref_codes = [] + for wav, sr in normalized: + ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0]) + + items: list[VoiceClonePromptItem] = [] + for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)): + if not xvec_only: + if rtext is None or rtext == "": + rtext = "For profile run" + logger.warning( + f"ref_text is required when x_vector_only_mode=False (ICL mode). " + f"Bad index={i}. Please check if it is profile run or " + f"you missed to provide ref_text." + ) + # raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}") + + wav_resample = wav + if sr != self.model.speaker_encoder_sample_rate: + wav_resample = librosa.resample( + y=wav_resample.astype(np.float32), orig_sr=int(sr), target_sr=self.model.speaker_encoder_sample_rate + ) + + spk_emb = self.model.extract_speaker_embedding( + audio=wav_resample, sr=self.model.speaker_encoder_sample_rate + ) + + items.append( + VoiceClonePromptItem( + ref_code=None if xvec_only else code, + ref_spk_embedding=spk_emb, + x_vector_only_mode=bool(xvec_only), + icl_mode=bool(not xvec_only), + ref_text=rtext, + ) + ) + return items + + def _prompt_items_to_voice_clone_prompt(self, items: list[VoiceClonePromptItem]) -> dict[str, Any]: + return dict( + ref_code=[it.ref_code for it in items], + ref_spk_embedding=[it.ref_spk_embedding for it in items], + x_vector_only_mode=[it.x_vector_only_mode for it in items], + icl_mode=[it.icl_mode for it in items], + ) + + # voice clone model + @torch.no_grad() + def generate_voice_clone( + self, + text: str | list[str], + language: str | list[str] = None, + ref_audio: AudioLike | list[AudioLike] | None = None, + ref_text: str | list[str | None] | None = None, + x_vector_only_mode: bool | list[bool] = False, + voice_clone_prompt: dict[str, Any] | list[VoiceClonePromptItem] | None = None, + **kwargs: Any, + ) -> tuple[list[np.ndarray], int]: + """ + Voice clone speech using the Base model. + + You can provide either: + - (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR + - `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR + - a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`. + + `ref_audio` Supported forms: + - str: wav path / URL / base64 audio string + - (np.ndarray, sr): waveform + sampling rate + - list of the above + + Input flexibility: + - text/language can be scalar or list. + - prompt can be single or batch. 
+ - If batch mode (len(text)>1), lengths must match. + + Args: + text: + Text(s) to synthesize. + language: + Language(s) for each sample. + ref_audio: + Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided. + ref_text: + Reference text(s) used for ICL mode (required when x_vector_only_mode=False). + x_vector_only_mode: + If True, only speaker embedding is used (ignores ref_text/ref_code). + If False, ICL mode is used automatically. + voice_clone_prompt: + list[VoiceClonePromptItem] from `create_voice_clone_prompt`. + **kwargs: + Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`, + `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`, + `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace + Transformers `generate()` can also be passed and will be forwarded to + `Qwen3TTSForConditionalGeneration.generate(...)`. + + Returns: + Tuple[List[np.ndarray], int]: + (wavs, sample_rate) + + Raises: + ValueError: + If batch sizes mismatch or required prompt inputs are missing. + """ + if self.model.tts_model_type != "base": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_voice_clone, Please check Model Card or Readme for more details." + ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(texts) != len(languages): + raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}") + + self._validate_languages(languages) + + if voice_clone_prompt is None: + if ref_audio is None: + # For profile run + sample_rate = int(self.model.speaker_encoder_sample_rate) + # Use a 1-second silent clip to satisfy padding requirements. + ref_audio = (np.zeros(sample_rate, dtype=np.float32), sample_rate) + logger.warning( + "ref_audio is not provided. Using a 1-second silent clip " + "to satisfy padding requirements. Please check if it is " + "profile run or you missed to provide ref_audio." 
+ ) + prompt_items = self.create_voice_clone_prompt( + ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode + ) + if len(prompt_items) == 1 and len(texts) > 1: + prompt_items = prompt_items * len(texts) + if len(prompt_items) != len(texts): + raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}") + voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items) + ref_texts_for_ids = [it.ref_text for it in prompt_items] + else: + if isinstance(voice_clone_prompt, list): + prompt_items = voice_clone_prompt + if len(prompt_items) == 1 and len(texts) > 1: + prompt_items = prompt_items * len(texts) + if len(prompt_items) != len(texts): + raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}") + voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items) + ref_texts_for_ids = [it.ref_text for it in prompt_items] + else: + voice_clone_prompt_dict = voice_clone_prompt + ref_texts_for_ids = None + + input_texts = [self._build_assistant_text(t) for t in texts] + input_ids = self._tokenize_texts(input_texts) + + ref_ids = None + if ref_texts_for_ids is not None: + ref_ids = [] + for i, rt in enumerate(ref_texts_for_ids): + if rt is None or rt == "": + ref_ids.append(None) + else: + ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0] + ref_ids.append(ref_tok) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + talker_codes_list, _ = self.model.generate( + input_ids=input_ids, + ref_ids=ref_ids, + voice_clone_prompt=voice_clone_prompt_dict, + languages=languages, + **gen_kwargs, + ) + + codes_for_decode = [] + for i, codes in enumerate(talker_codes_list): + ref_code_list = voice_clone_prompt_dict.get("ref_code", None) + if ref_code_list is not None and ref_code_list[i] is not None: + codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0)) + else: + codes_for_decode.append(codes) + + wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode]) + + wavs_out: list[np.ndarray] = [] + for i, wav in enumerate(wavs_all): + ref_code_list = voice_clone_prompt_dict.get("ref_code", None) + if ref_code_list is not None and ref_code_list[i] is not None: + ref_len = int(ref_code_list[i].shape[0]) + total_len = int(codes_for_decode[i].shape[0]) + cut = int(ref_len / max(total_len, 1) * wav.shape[0]) + wavs_out.append(wav[cut:]) + else: + wavs_out.append(wav) + + return wavs_out, fs + + # voice design model + @torch.no_grad() + def generate_voice_design( + self, + text: str | list[str], + instruct: str | list[str], + language: str | list[str] = None, + **kwargs: Any, + ) -> tuple[list[np.ndarray], int]: + """ + Generate speech with the VoiceDesign model using natural-language style instructions. + + Args: + text: + Text(s) to synthesize. + language: + Language(s) for each sample. + instruct: + Instruction(s) describing desired voice/style. Empty string is allowed (treated as no instruction). + **kwargs: + Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`, + `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`, + `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace + Transformers `generate()` can also be passed and will be forwarded to + `Qwen3TTSForConditionalGeneration.generate(...)`. 
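+
+        Example (illustrative; `tts` is a `Qwen3TTSModel` wrapping a VoiceDesign checkpoint and the
+        instruction text is made up):
+            wavs, sr = tts.generate_voice_design(
+                text="Welcome aboard!",
+                instruct="A warm, low-pitched narrator speaking slowly.",
+            )
+            sf.write("voice_design_demo.wav", wavs[0], sr)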
+ + Returns: + Tuple[List[np.ndarray], int]: + (wavs, sample_rate) + """ + if self.model.tts_model_type != "voice_design": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_voice_design, Please check Model Card or Readme for more details." + ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + instructs = self._ensure_list(instruct) + + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(instructs) == 1 and len(texts) > 1: + instructs = instructs * len(texts) + + if not (len(texts) == len(languages) == len(instructs)): + raise ValueError( + f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}" + ) + + self._validate_languages(languages) + + input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts]) + + instruct_ids: list[torch.Tensor | None] = [] + for ins in instructs: + if ins is None or ins == "": + instruct_ids.append(None) + else: + instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0]) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + talker_codes_list, _ = self.model.generate( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=languages, + **gen_kwargs, + ) + + wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list]) + return wavs, fs + + # custom voice model + @torch.no_grad() + def generate_custom_voice( + self, + text: str | list[str], + speaker: str | list[str], + language: str | list[str] = None, + instruct: str | list[str] | None = None, + **kwargs: Any, + ) -> tuple[list[np.ndarray], int]: + """ + Generate speech with the CustomVoice model using a predefined speaker id, + optionally controlled by instruction text. + + Args: + text: + Text(s) to synthesize. + language: + Language(s) for each sample. + speaker: + Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive). + instruct: + Optional instruction(s). If None, treated as empty (no instruction). + **kwargs: + Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`, + `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`, + `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace + Transformers `generate()` can also be passed and will be forwarded to + `Qwen3TTSForConditionalGeneration.generate(...)`. + + Returns: + Tuple[List[np.ndarray], int]: + (wavs, sample_rate) + + Raises: + ValueError: + If any speaker/language is unsupported or batch sizes mismatch. + """ + if self.model.tts_model_type != "custom_voice": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_custom_voice, Please check Model Card or Readme for more details." 
+ ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + speakers = self._ensure_list(speaker) + if self.model.tts_model_size in "0b6": # for 0b6 model, instruct is not supported + instruct = None + instructs = ( + self._ensure_list(instruct) + if isinstance(instruct, list) + else ([instruct] * len(texts) if instruct is not None else [""] * len(texts)) + ) + + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(speakers) == 1 and len(texts) > 1: + speakers = speakers * len(texts) + if len(instructs) == 1 and len(texts) > 1: + instructs = instructs * len(texts) + + if not (len(texts) == len(languages) == len(speakers) == len(instructs)): + raise ValueError( + f"Batch size mismatch: text={len(texts)}, " + f"language={len(languages)}, speaker={len(speakers)}, " + f"instruct={len(instructs)}" + ) + + self._validate_languages(languages) + self._validate_speakers(speakers) + + input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts]) + + instruct_ids: list[torch.Tensor | None] = [] + for ins in instructs: + if ins is None or ins == "": + instruct_ids.append(None) + else: + instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0]) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + talker_codes_list, _ = self.model.generate( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=languages, + speakers=speakers, + **gen_kwargs, + ) + + wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list]) + return wavs, fs + + def get_supported_speakers(self) -> list[str] | None: + """ + List supported speaker names for the current model. + + This is a convenience wrapper around `model.get_supported_speakers()`. + If the underlying model does not expose speaker constraints (returns None), + this method also returns None. + + Returns: + Optional[List[str]]: + - A sorted list of supported speaker names (lowercased), if available. + - None if the model does not provide supported speakers. + """ + supported = self._supported_speakers_set() + if supported is None: + return None + return sorted(supported) + + def get_supported_languages(self) -> list[str] | None: + """ + List supported language names for the current model. + + This is a convenience wrapper around `model.get_supported_languages()`. + If the underlying model does not expose language constraints (returns None), + this method also returns None. + + Returns: + Optional[List[str]]: + - A sorted list of supported language names (lowercased), if available. + - None if the model does not provide supported languages. + """ + supported = self._supported_languages_set() + if supported is None: + return None + return sorted(supported) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e50211988ee973f390428de0267ed55434592f --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -0,0 +1,417 @@ +# Copyright 2026 The Alibaba Qwen team. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import io +import urllib.request +from urllib.parse import urlparse + +import librosa +import numpy as np +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoConfig, AutoFeatureExtractor, AutoModel + +from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config +from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( + Qwen3TTSTokenizerV2EncoderOutput, + Qwen3TTSTokenizerV2Model, +) +from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config +from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import ( + Qwen3TTSTokenizerV1EncoderOutput, + Qwen3TTSTokenizerV1Model, +) + +AudioInput = ( + str # wav path, or base64 string + | np.ndarray # 1-D float array + | list[str] + | list[np.ndarray] +) + + +class Qwen3TTSTokenizer: + """ + A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading. + + - from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor. + - encode(): supports wav path(s), base64 audio string(s), numpy array(s). + - decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts. + + Notes: + - For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate. + - Returned audio is float32 numpy arrays and the output sample rate. + """ + + def __init__(self): + self.model = None + self.feature_extractor = None + self.config = None + self.device = None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer": + """ + Initialize tokenizer with HuggingFace `from_pretrained` style. + + Args: + pretrained_model_name_or_path (str): + HuggingFace repo id or local directory. + **kwargs (Any): + Forwarded to `AutoModel.from_pretrained(...)` directly. + Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager". + + Returns: + Qwen3TTSTokenizer: + Initialized instance with `model`, `feature_extractor`, `config`. + """ + inst = cls() + + AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) + AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) + + AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config) + AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model) + + inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) + inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) + inst.config = inst.model.config + + inst.device = getattr(inst.model, "device", None) + if inst.device is None: + # fallback: infer from first parameter device + try: + inst.device = next(inst.model.parameters()).device + except StopIteration: + inst.device = torch.device("cpu") + + return inst + + def _is_probably_base64(self, s: str) -> bool: + if s.startswith("data:audio"): + return True + # Heuristic: no filesystem path separators and long enough. 
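+        # A genuine base64-encoded clip is far longer than 256 characters, while file paths rarely are.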
+ if ("/" not in s and "\\" not in s) and len(s) > 256: + return True + return False + + def _is_url(self, s: str) -> bool: + try: + u = urlparse(s) + return u.scheme in ("http", "https") and bool(u.netloc) + except Exception: + return False + + def _decode_base64_to_wav_bytes(self, b64: str) -> bytes: + # Accept both "data:audio/wav;base64,...." and raw base64 + if "," in b64 and b64.strip().startswith("data:"): + b64 = b64.split(",", 1)[1] + return base64.b64decode(b64) + + def load_audio( + self, + x: str, + target_sr: int, + ) -> np.ndarray: + """ + Load audio from wav path or base64 string, then resample to target_sr. + + Args: + x (str): + A wav file path, or a base64 audio string (raw or data URL). + target_sr (int): + Target sampling rate. + + Returns: + np.ndarray: + 1-D float32 waveform at target_sr. + """ + if self._is_url(x): + with urllib.request.urlopen(x) as resp: + audio_bytes = resp.read() + with io.BytesIO(audio_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + elif self._is_probably_base64(x): + wav_bytes = self._decode_base64_to_wav_bytes(x) + with io.BytesIO(wav_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + else: + audio, sr = librosa.load(x, sr=None, mono=True) + + if audio.ndim > 1: + audio = np.mean(audio, axis=-1) + + if sr != target_sr: + audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr) + + return audio.astype(np.float32) + + def _normalize_audio_inputs( + self, + audios: AudioInput, + sr: int | None, + ) -> list[np.ndarray]: + """ + Normalize all supported input types into a list of 1-D numpy float32 waveforms + at `self.feature_extractor.sampling_rate`. + + Args: + audios (AudioInput): + - str: wav path OR base64 audio string + - np.ndarray: raw waveform (sr must be provided) + - list[str] / list[np.ndarray] + sr (Optional[int]): + Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray]. + + Returns: + List[np.ndarray]: + List of float32 waveforms resampled to model input SR. + """ + target_sr = int(self.feature_extractor.sampling_rate) + + if isinstance(audios, (str, np.ndarray)): + audios = [audios] + + if len(audios) == 0: + return [] + + if isinstance(audios[0], str): + # wav path list or base64 list + return [self.load_audio(x, target_sr=target_sr) for x in audios] # type: ignore[arg-type] + + # numpy list + if sr is None: + raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).") + + out: list[np.ndarray] = [] + for a in audios: # type: ignore[assignment] + if not isinstance(a, np.ndarray): + raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.") + if a.ndim > 1: + a = np.mean(a, axis=-1) + if int(sr) != target_sr: + a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr) + out.append(a.astype(np.float32)) + return out + + def encode( + self, + audios: AudioInput, + sr: int | None = None, + return_dict: bool = True, + ) -> ( + Qwen3TTSTokenizerV1EncoderOutput + | Qwen3TTSTokenizerV2EncoderOutput + | tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None] + | tuple[list[torch.Tensor]] + ): + """ + Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz). 
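+        Inputs are first normalized to 1-D float32 waveforms at the feature extractor's sampling rate
+        (see `_normalize_audio_inputs`), then padded by the feature extractor and passed to `model.encode`.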
+ + Args: + audios (AudioInput): + Supported forms: + - np.ndarray: waveform (requires sr) + - list[np.ndarray]: waveforms (requires sr) + - str: wav path OR base64 audio string + - list[str]: wav paths and/or base64 strings + sr (Optional[int], default=None): + Original sampling rate for numpy waveform input. + return_dict (bool, default=True): + Forwarded to model.encode(...). If True, returns ModelOutput. + + Returns: + Qwen3TTSTokenizerV1EncoderOutput | Qwen3TTSTokenizerV2EncoderOutput | tuple: + Encoder output or tuple returned by model.encode. If return_dict=True, + returns a model-specific encoder output. For 25Hz models, this includes + audio_codes/xvectors/ref_mels; for 12Hz models, this includes audio_codes. + If return_dict=False, returns the raw tuple from model.encode. + """ + wavs = self._normalize_audio_inputs(audios, sr=sr) + + inputs = self.feature_extractor( + raw_audio=wavs, + sampling_rate=int(self.feature_extractor.sampling_rate), + return_tensors="pt", + ) + inputs = inputs.to(self.device).to(self.model.dtype) + + with torch.inference_mode(): + # model.encode expects (B, T) and (B, T) + enc = self.model.encode( + inputs["input_values"].squeeze(1), + inputs["padding_mask"].squeeze(1), + return_dict=return_dict, + ) + return enc + + def decode( + self, + encoded, + ) -> tuple[list[np.ndarray], int]: + """ + Decode back to waveform. + + Usage: + 1) Pass the raw output of `encode(...)` directly (recommended). + - 25Hz: expects fields audio_codes, xvectors, ref_mels + - 12Hz: expects field audio_codes + 2) Pass a dict or list[dict] (minimal form) for custom pipelines: + - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"} + - 12Hz dict keys: {"audio_codes"} + Values can be torch tensors or numpy arrays. + + Args: + encoded (Any): + - ModelOutput returned by `encode()`, OR + - dict, OR + - list[dict] + + Returns: + Tuple[List[np.ndarray], int]: + - wavs: list of 1-D float32 numpy arrays + - sample_rate: int, model output sampling rate + """ + model_type = self.model.get_model_type() + + def _to_tensor(x, dtype=None): + if isinstance(x, torch.Tensor): + return x + x = np.asarray(x) + t = torch.from_numpy(x) + if dtype is not None: + t = t.to(dtype) + return t + + # Normalize `encoded` into the same shapes as the official demo uses. + if hasattr(encoded, "audio_codes"): + # ModelOutput from encode() + audio_codes_list = encoded.audio_codes + xvectors_list = getattr(encoded, "xvectors", None) + ref_mels_list = getattr(encoded, "ref_mels", None) + elif isinstance(encoded, dict): + audio_codes_list = encoded["audio_codes"] + xvectors_list = encoded.get("xvectors", None) + ref_mels_list = encoded.get("ref_mels", None) + elif isinstance(encoded, list): + # list of dicts + audio_codes_list = [e["audio_codes"] for e in encoded] + xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None + ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None + else: + raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.") + + # Ensure list form for per-sample tensors + if isinstance(audio_codes_list, torch.Tensor): + # Could be a single sample tensor or an already padded batch tensor. 
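+            # Accepted shapes: (C,) for a single 25Hz sample, (C, Q) for a single 12Hz sample,
+            # or an already batched tensor, which is used as-is.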
+ t = audio_codes_list + if t.dim() == 1: + # 25Hz single sample: (C,) -> (1, C) + t = t.unsqueeze(0) + elif t.dim() == 2: + # 12Hz single sample: (C, Q) -> (1, C, Q) + t = t.unsqueeze(0) + audio_codes_padded = t.to(self.device) + else: + # List[Tensor/np] + audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list] + audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device) + + with torch.inference_mode(): + if model_type == "qwen3_tts_tokenizer_25hz": + if xvectors_list is None or ref_mels_list is None: + raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.") + + if isinstance(xvectors_list, torch.Tensor): + xvectors_batch = xvectors_list + if xvectors_batch.dim() == 1: # (D,) -> (1, D) + xvectors_batch = xvectors_batch.unsqueeze(0) + xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype) + else: + xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list] + xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype) + + if isinstance(ref_mels_list, torch.Tensor): + ref_mels_padded = ref_mels_list + if ref_mels_padded.dim() == 2: # (T, M) -> (1, T, M) + ref_mels_padded = ref_mels_padded.unsqueeze(0) + ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype) + else: + ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list] + ref_mels_padded = ( + pad_sequence(ref_mels_list, batch_first=True, padding_value=0) + .to(self.device) + .to(self.model.dtype) + ) + + dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True) + wav_tensors = dec.audio_values + + elif model_type == "qwen3_tts_tokenizer_12hz": + dec = self.model.decode(audio_codes_padded, return_dict=True) + wav_tensors = dec.audio_values + + else: + raise ValueError(f"Unknown model type: {model_type}") + + wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors] + return wavs, int(self.model.get_output_sample_rate()) + + def get_model_type(self) -> str: + """ + Get the underlying tokenizer model type. + + Returns: + str: Model type string from `self.model.config.model_type` + (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz"). + """ + return self.model.get_model_type() + + def get_input_sample_rate(self) -> int: + """ + Get the expected input sample rate for encoding. + + Returns: + int: Input sample rate (Hz). + """ + return int(self.model.get_input_sample_rate()) + + def get_output_sample_rate(self) -> int: + """ + Get the output sample rate for decoded waveforms. + + Returns: + int: Output sample rate (Hz). + """ + return int(self.model.get_output_sample_rate()) + + def get_encode_downsample_rate(self) -> int: + """ + Get the encoder downsample rate (waveform samples per code step). + + Returns: + int: Encode downsample rate. + """ + return int(self.model.get_encode_downsample_rate()) + + def get_decode_upsample_rate(self) -> int: + """ + Get the decoder upsample rate (waveform samples per code step). + + Returns: + int: Decode upsample rate. + """ + return int(self.model.get_decode_upsample_rate()) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/__init__.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db6952f1db83b991b615c71b9f6fa0610755a366 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/__init__.py @@ -0,0 +1 @@ +# Qwen3 TTS 12Hz tokenizer package. 
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3a6499e9435a58a4f1332972a9a64a0e7b6cee --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py @@ -0,0 +1,171 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3TTSTokenizerV2 model configuration""" + +from transformers import MimiConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2DecoderConfig`]. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + codebook_size (`int`, *optional*, defaults to 2048): + Number of entries in each residual codebook used for acoustic token quantization. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder. + max_position_embeddings (`int`, *optional*, defaults to 8000): + Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period for rotary position embeddings (RoPE) applied to attention layers. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key and value attention heads used in grouped-query attention (if applicable). + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the attention projection layers. + sliding_window (`int`, *optional*, defaults to 72): + Window size for local attention mechanism, limiting attention context to improve efficiency. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the feed-forward (intermediate) layer in each transformer block. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function used in the feed-forward layers. + Supports `"silu"`, `"relu"`, `"gelu"`, etc. + layer_scale_initial_scale (`float`, *optional*, defaults to 0.01): + Initial value for LayerScale applied in transformer blocks, helping stabilize training. + rms_norm_eps (`float`, *optional*, defaults to 1e-5): + Epsilon value for RMS normalization layers to prevent division by zero. 
+ num_hidden_layers (`int`, *optional*, defaults to 8): + Number of transformer blocks in the autoregressive decoder. + num_quantizers (`int`, *optional*, defaults to 16): + Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction. + upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`): + Rate at which features are upsampled in the final waveform synthesis stage. + upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`): + Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform. + decoder_dim (`int`, *optional*, defaults to 1536): + Final dimensionality of the decoder's output before waveform generation. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability applied to attention weights in the decoder. + """ + + def __init__( + self, + codebook_size=2048, + hidden_size=1024, + latent_dim=1024, + max_position_embeddings=8000, + rope_theta=10000, + num_attention_heads=16, + num_key_value_heads=16, + attention_bias=False, + sliding_window=72, + intermediate_size=3072, + hidden_act="silu", + layer_scale_initial_scale=0.01, + rms_norm_eps=1e-5, + num_hidden_layers=8, + num_quantizers=16, + upsample_rates=(8, 5, 4, 3), + upsampling_ratios=(2, 2), + decoder_dim=1536, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.codebook_size = codebook_size + self.hidden_size = hidden_size + self.latent_dim = latent_dim + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.sliding_window = sliding_window + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.layer_scale_initial_scale = layer_scale_initial_scale + self.rms_norm_eps = rms_norm_eps + self.num_hidden_layers = num_hidden_layers + self.num_quantizers = num_quantizers + self.upsample_rates = upsample_rates + self.upsampling_ratios = upsampling_ratios + self.decoder_dim = decoder_dim + self.attention_dropout = attention_dropout + + @property + def layer_types(self): + """ + All layer in code2wav should be sliding attention + """ + return ["sliding_attention"] * self.num_hidden_layers + + +class Qwen3TTSTokenizerV2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2Config`]. + It is used to instantiate a Qwen3TTSTokenizerV2Model model according to the specified + sub-models configurations, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model. + decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model. + """ + + model_type = "qwen3_tts_tokenizer_12hz" + sub_configs = { + "encoder_config": MimiConfig, + "decoder_config": Qwen3TTSTokenizerV2DecoderConfig, + } + + def __init__( + self, + encoder_config=None, + decoder_config=None, + encoder_valid_num_quantizers=16, + input_sample_rate=24000, + output_sample_rate=24000, + decode_upsample_rate=1920, + encode_downsample_rate=1920, + **kwargs, + ): + super().__init__(**kwargs) + if encoder_config is None: + encoder_config = {} + logger.info("encoder_config is None. 
Initializing encoder with default values") + if decoder_config is None: + decoder_config = {} + logger.info("decoder_config is None. Initializing decoder with default values") + + self.encoder_config = MimiConfig(**encoder_config) + self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config) + + self.encoder_valid_num_quantizers = encoder_valid_num_quantizers + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.decode_upsample_rate = decode_upsample_rate + self.encode_downsample_rate = encode_downsample_rate + + +__all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..31abe254b70b6e99b52b70814489defbaebfce2f --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -0,0 +1,1005 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3TTSTokenizerV2 model.""" + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from transformers import MimiConfig, MimiModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ModelOutput, auto_docstring, logging +from transformers.utils.deprecation import deprecate_kwarg + +from .configuration_qwen3_tts_tokenizer_v2 import ( + Qwen3TTSTokenizerV2Config, + Qwen3TTSTokenizerV2DecoderConfig, +) + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring +class Qwen3TTSTokenizerV2EncoderOutput(ModelOutput): + r""" + audio_codes (`List[torch.LongTensor]`): + Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i, num_quantizers). 
+ """ + + audio_codes: list[torch.LongTensor] = None + + +@dataclass +@auto_docstring +class Qwen3TTSTokenizerV2DecoderOutput(ModelOutput): + r""" + audio_values (`List[torch.FloatTensor]`): + Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1. + Each tensor has shape (segment_length_i). + """ + + audio_values: list[torch.FloatTensor] = None + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@auto_docstring +class Qwen3TTSTokenizerV2DecoderPreTrainedModel(PreTrainedModel): + config: Qwen3TTSTokenizerV2DecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +class Qwen3TTSTokenizerV2CausalConvNet(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + stride=1, + groups=1, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + self.padding = self.kernel_size - self.stride + + def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int: + length = hidden_state.shape[-1] + n_frames = (length - self.kernel_size + self.padding) / self.stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * self.stride + (self.kernel_size - self.padding) + return ideal_length - length + + def forward(self, hidden_state): + extra_padding = self._get_extra_padding_for_conv1d(hidden_state) + hidden_state = F.pad(hidden_state, (self.padding, extra_padding), mode="constant", value=0) + return self.conv(hidden_state).contiguous() + + +class Qwen3TTSTokenizerV2CausalTransConvNet(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1): + super().__init__() + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride) + + pad = kernel_size - stride + self.left_pad = math.ceil(pad) + self.right_pad = pad = self.left_pad + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = hidden_state[..., self.left_pad : hidden_state.shape[-1] - self.right_pad] + return hidden_state.contiguous() + + +class Qwen3TTSTokenizerV2ConvNeXtBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dwconv = Qwen3TTSTokenizerV2CausalConvNet( + dim, + dim, + kernel_size=7, + groups=dim, + dilation=1, + ) + self.norm = nn.LayerNorm(dim, eps=1e-6) + 
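+ # Pointwise (1x1) convolutions implemented as Linear layers over the channel dim: expand 4x, then project back (ConvNeXt inverted bottleneck).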
self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(1e-6 * torch.ones(dim)) + + def forward(self, hidden_states): + input = hidden_states + + hidden_states = self.dwconv(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.norm(hidden_states) + hidden_states = self.pwconv1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.pwconv2(hidden_states) + + hidden_states = self.gamma * hidden_states + + hidden_states = hidden_states.permute(0, 2, 1) + + hidden_states = input + hidden_states + + return hidden_states + + +class Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3TTSTokenizerV2DecoderAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + self.sliding_window = config.sliding_window + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def 
forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3TTSTokenizerV2DecoderMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3TTSTokenizerV2DecoderRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3TTSTokenizerV2DecoderRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3TTSTokenizerV2DecoderLayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239). + This rescales diagonally the residual outputs close to 0, with a learnt scale. 
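+ Implemented as output = scale * x, where scale is initialised to config.layer_scale_initial_scale.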
+ """ + + def __init__(self, config): + super().__init__() + channels = config.hidden_size + initial_scale = config.layer_scale_initial_scale + self.scale = nn.Parameter(torch.full((channels,), initial_scale, requires_grad=True)) + + def forward(self, x: torch.Tensor): + return self.scale * x + + +class Qwen3TTSTokenizerV2DecoderTransformerLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3TTSTokenizerV2DecoderAttention(config, layer_idx) + self.mlp = Qwen3TTSTokenizerV2DecoderMlp(config) + self.input_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps) + self.self_attn_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config) + self.mlp_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config) + self.attention_type = "sliding_attention" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self.self_attn_layer_scale(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.mlp_layer_scale(hidden_states) + + return hidden_states + + +@auto_docstring +class Qwen3TTSTokenizerV2DecoderTransformerModel(Qwen3TTSTokenizerV2DecoderPreTrainedModel): + _can_record_outputs = { + "hidden_states": Qwen3TTSTokenizerV2DecoderTransformerLayer, + "attentions": Qwen3TTSTokenizerV2DecoderAttention, + } + + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [ + Qwen3TTSTokenizerV2DecoderTransformerLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + self.window_size = config.sliding_window + + self.input_proj = nn.Linear(config.latent_dim, config.hidden_size) + self.output_proj = nn.Linear(config.hidden_size, config.latent_dim) + + # Initialize weights and apply final processing + self.post_init() + + # Note: @check_model_inputs decorator removed for vLLM compatibility + # The decorator causes "unexpected keyword argument 'inputs_embeds'" error + @auto_docstring + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + cache_position=None, + **kwargs, + ) -> BaseModelOutputWithPast: + if input_ids is not None: + raise ValueError("input_ids is not expected") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = self.input_proj(inputs_embeds) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.output_proj(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper + by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. 
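+ alpha and beta are stored in log scale and exponentiated inside the forward pass.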
+ SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + + self.act1 = SnakeBeta(dim) + self.conv1 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, dilation=dilation) + self.act2 = SnakeBeta(dim) + self.conv2 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=1) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.act1(hidden_state) + hidden_state = self.conv1(hidden_state) + hidden_state = self.act2(hidden_state) + hidden_state = self.conv2(hidden_state) + return hidden_state + residual + + +class Qwen3TTSTokenizerV2DecoderDecoderBlock(Qwen3TTSTokenizerV2DecoderPreTrainedModel): + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx): + super().__init__(config) + in_dim = config.decoder_dim // 2**layer_idx + out_dim = config.decoder_dim // 2 ** (layer_idx + 1) + upsample_rate = config.upsample_rates[layer_idx] + + block = [ + SnakeBeta(in_dim), + Qwen3TTSTokenizerV2CausalTransConvNet(in_dim, out_dim, 2 * upsample_rate, upsample_rate), + ] + + for dilation in (1, 3, 9): + block.append(Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(out_dim, dilation)) + + self.block = nn.ModuleList(block) + + def forward(self, hidden): + for block in self.block: + hidden = block(hidden) + return hidden + + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim: int, + codebook_size: int, + epsilon: float = 1e-5, + ): + super().__init__() + self.dim = dim + self.codebook_size = codebook_size + self.epsilon = epsilon + + self.cluster_usage = nn.Parameter(torch.ones(codebook_size)) + self.embedding_sum = nn.Parameter(torch.zeros(codebook_size, dim)) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + embedding = self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None] + quantized = F.embedding(codes, embedding) + return quantized + + +class VectorQuantization(nn.Module): + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: int | None = None, + epsilon: float = 1e-5, + ): + super().__init__() + if codebook_dim is None: + codebook_dim = dim + + requires_projection = codebook_dim != dim + + self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + self.epsilon = epsilon + self._codebook = EuclideanCodebook(dim=codebook_dim, codebook_size=codebook_size, epsilon=epsilon) + self.codebook_size = codebook_size + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + quantized = self._codebook.decode(codes) + quantized = self.project_out(quantized) + quantized = quantized.transpose(1, 2) + return quantized + + +class ResidualVectorQuantization(nn.Module): + def __init__(self, *, num_quantizers: int, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + quantized = torch.zeros([1], device=codes.device)[0] + for idx, layer_codes in enumerate(codes): + layer = self.layers[idx] + assert isinstance(layer, VectorQuantization) + quantized = quantized + layer.decode(layer_codes) + 
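+ # The reconstructed latent is the sum of the per-layer codebook lookups (residual quantization).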
return quantized + + +class ResidualVectorQuantizer(nn.Module): + def __init__( + self, + dimension: int = 128, + input_dimension: int | None = None, + output_dimension: int | None = None, + n_q: int = 8, + q_dropout: bool = False, + no_quantization_rate: float = 0.0, + bins: int = 1024, + decay: float = 0.99, + force_projection: bool = False, + ): + super().__init__() + self.max_n_q = n_q + self.n_q = n_q + self.q_dropout = q_dropout + self.no_quantization_rate = no_quantization_rate + self.dimension = dimension + self.input_dimension = input_dimension or dimension + self.output_dimension = output_dimension or dimension + self.bins = bins + self.decay = decay + self.input_proj: torch.nn.Module + self.output_proj: torch.nn.Module + if self.input_dimension == self.dimension and not force_projection: + self.input_proj = torch.nn.Identity() + else: + self.input_proj = torch.nn.Conv1d(self.input_dimension, self.dimension, 1, bias=False) + if self.output_dimension == self.dimension and not force_projection: + self.output_proj = torch.nn.Identity() + else: + self.output_proj = torch.nn.Conv1d(self.dimension, self.output_dimension, 1, bias=False) + self.vq = ResidualVectorQuantization(dim=self.dimension, codebook_size=self.bins, num_quantizers=self.n_q) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + codes = codes.transpose(0, 1) + quantized = self.vq.decode(codes) + quantized = self.output_proj(quantized) + return quantized + + +class SplitResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer with separate projections for the first quantizer and the rest. + + Args: + n_q (int): Number of residual vector quantizers used. + n_q_semantic (int): Number of residual vector quantizers used for the semantic quantizer. + **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both. + """ + + def __init__( + self, + *, + n_q: int = 8, + n_q_semantic: int = 1, + **kwargs: Any, + ): + super().__init__() + assert n_q > n_q_semantic, ( + f"Number of quantizers {n_q} must be larger than the number of semantic quantizers {n_q_semantic}." + ) + self.max_n_q = n_q + self.n_q_semantic = n_q_semantic + self.n_q_acoustic = n_q - n_q_semantic + q_dropout = kwargs.pop("q_dropout", False) + self.rvq_first = ResidualVectorQuantizer(n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs) + self.rvq_rest = ResidualVectorQuantizer( + n_q=n_q - n_q_semantic, + force_projection=True, + q_dropout=q_dropout, + **kwargs, + ) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + # codes is [B, K, T], with T frames, K nb of codebooks. 
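+ # The first n_q_semantic codebooks are decoded by the semantic quantizer; remaining acoustic codebooks are decoded separately and added on top.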
+ quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic]) + if codes.shape[1] > self.n_q_semantic: + quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :]) + return quantized + + +class Qwen3TTSTokenizerV2Decoder(Qwen3TTSTokenizerV2DecoderPreTrainedModel): + def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig): + super().__init__(config) + self.total_upsample = np.prod(config.upsample_rates + config.upsampling_ratios) + self.pre_transformer = Qwen3TTSTokenizerV2DecoderTransformerModel._from_config(config) + + self.quantizer = SplitResidualVectorQuantizer( + dimension=config.codebook_dim // 2, + n_q=config.num_quantizers, + n_q_semantic=1, + bins=config.codebook_size, + input_dimension=config.codebook_dim, + output_dimension=config.codebook_dim, + ) + + self.pre_conv = Qwen3TTSTokenizerV2CausalConvNet( + config.codebook_dim, + config.latent_dim, + kernel_size=3, + ) + + upsample = [] + for factor in config.upsampling_ratios: + upsample.append( + nn.ModuleList( + [ + Qwen3TTSTokenizerV2CausalTransConvNet(config.latent_dim, config.latent_dim, factor, factor), + Qwen3TTSTokenizerV2ConvNeXtBlock(config.latent_dim), + ] + ) + ) + self.upsample = nn.ModuleList(upsample) + + decoder = [Qwen3TTSTokenizerV2CausalConvNet(config.latent_dim, config.decoder_dim, 7)] + for i in range(len(config.upsample_rates)): + decoder.append(Qwen3TTSTokenizerV2DecoderDecoderBlock(config, i)) + output_dim = config.decoder_dim // 2 ** len(config.upsample_rates) + decoder += [ + SnakeBeta(output_dim), + Qwen3TTSTokenizerV2CausalConvNet(output_dim, 1, 7), + ] + self.decoder = nn.ModuleList(decoder) + + self.post_init() + + def forward(self, codes): + if codes.shape[1] != self.config.num_quantizers: + raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}") + + hidden = self.quantizer.decode(codes) + hidden = self.pre_conv(hidden).transpose(1, 2) + + hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state + hidden = hidden.permute(0, 2, 1) + for blocks in self.upsample: + for block in blocks: + hidden = block(hidden) + wav = hidden + for block in self.decoder: + wav = block(wav) + return wav.clamp(min=-1, max=1) + + def chunked_decode(self, codes, chunk_size=300, left_context_size=25): + wavs = [] + start_index = 0 + while start_index < codes.shape[-1]: + end_index = min(start_index + chunk_size, codes.shape[-1]) + context_size = left_context_size if start_index - left_context_size > 0 else start_index + codes_chunk = codes[..., start_index - context_size : end_index] + wav_chunk = self(codes_chunk) + wavs.append(wav_chunk[..., context_size * self.total_upsample :]) + start_index = end_index + return torch.cat(wavs, dim=-1) + + +class Qwen3TTSTokenizerV2Encoder(MimiModel): + def __init__(self, config: MimiConfig): + super().__init__(config) + self.config = config + + self.upsample = None + self.decoder_transformer = None + self.decoder = None + + self.post_init() + + +@auto_docstring +class Qwen3TTSTokenizerV2PreTrainedModel(PreTrainedModel): + config: Qwen3TTSTokenizerV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +@auto_docstring( + custom_intro=""" + The Qwen3TTSTokenizerV2 model. 
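+ It pairs a Mimi-based encoder, which discretizes waveforms into codec tokens, with a transformer-plus-upsampling decoder that reconstructs audio from those tokens.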
+ """ +) +class Qwen3TTSTokenizerV2Model(Qwen3TTSTokenizerV2PreTrainedModel): + def __init__(self, config: Qwen3TTSTokenizerV2Config): + super().__init__(config) + self.config = config + + self.encoder_valid_num_quantizers = config.encoder_valid_num_quantizers + + self.input_sample_rate = config.input_sample_rate + self.output_sample_rate = config.output_sample_rate + + self.decode_upsample_rate = config.decode_upsample_rate + self.encode_downsample_rate = config.encode_downsample_rate + + self.encoder = Qwen3TTSTokenizerV2Encoder._from_config(self.config.encoder_config) + self.decoder = Qwen3TTSTokenizerV2Decoder._from_config(self.config.decoder_config) + + self.post_init() + + def get_model_type(self): + return self.config.model_type + + def get_input_sample_rate(self): + return self.input_sample_rate + + def get_output_sample_rate(self): + return self.output_sample_rate + + def get_encode_downsample_rate(self): + return self.encode_downsample_rate + + def get_decode_upsample_rate(self): + return self.decode_upsample_rate + + def encode( + self, + input_values: torch.Tensor, + padding_mask: torch.Tensor | None = None, + return_dict: bool | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | Qwen3TTSTokenizerV2EncoderOutput: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Float values of the input audio waveform. + padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Indicates which inputs are to be ignored due to padding, + where elements are either 1 for *not masked* or 0 for *masked*. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoded_frames = self.encoder.encode(input_values=input_values.unsqueeze(1), return_dict=True) + audio_codes = encoded_frames.audio_codes[:, : self.encoder_valid_num_quantizers] + audio_codes = [ + code[..., : -(-mask.sum() // self.encode_downsample_rate)].transpose(0, 1) + for code, mask in zip(audio_codes, padding_mask) + ] + + if not return_dict: + return (audio_codes,) + + return Qwen3TTSTokenizerV2EncoderOutput(audio_codes) + + def decode( + self, + audio_codes: torch.Tensor, + return_dict: bool | None = None, + ) -> tuple[torch.Tensor, torch.Tensor] | Qwen3TTSTokenizerV2DecoderOutput: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be + trimmed. + + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length, num_quantizers)`, *optional*): + Discrete code embeddings computed using `model.encode`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + audio_values = self.decoder.chunked_decode(audio_codes.transpose(1, 2)).squeeze(1) + + audio_lengths = (audio_codes[..., 0] > 0).sum(1) * self.decode_upsample_rate + audio_values = [a[:length] for a, length in zip(audio_values, audio_lengths)] + + if not return_dict: + return (audio_values,) + + return Qwen3TTSTokenizerV2DecoderOutput(audio_values) + + +__all__ = ["Qwen3TTSTokenizerV2Model", "Qwen3TTSTokenizerV2PreTrainedModel"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/__init__.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f49344f00f1c3f1eb43531a1fc2d5814671c9b4 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/__init__.py @@ -0,0 +1 @@ +# Qwen3 TTS 25Hz tokenizer package. diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..74272f936ce6d86cfb40a3e728a2e0b0b8d3b01e --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py @@ -0,0 +1,332 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3TTSTokenizerV1 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT. + It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + The dimension of the model. + num_hidden_layers (`int`, *optional*, defaults to 22): + The number of transformer blocks in the DiT model. + num_attention_heads (`int`, *optional*, defaults to 16): + The number of attention heads in each transformer block. + ff_mult (`int`, *optional*, defaults to 2): + The multiplier for the feedforward layer in each transformer block. + emb_dim (`int`, *optional*, defaults to 512): + The dimension of the embedding layer. + head_dim (`int`, *optional*, defaults to 64): + The dimension of each attention head. + repeats (`int`, *optional*, defaults to 2): + The number of times the codec embeddings are repeated. + num_embeds (`int`, *optional*, defaults to 8193): + The number of unique embeddings in the codec. + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + dropout (`float`, *optional*, defaults to 0.1): + The dropout rate for the transformer blocks. 
+ + enc_emb_dim (`int`, *optional*, defaults to 192): + The dimension of the pre-trained speaker embedding. + enc_dim (`int`, *optional*, defaults to 128): + The dimension of the encoder output. + enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`): + A list of output channels for each TDNN/SERes2Net layer in the encoder. + enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`): + A list of kernel sizes for each layer in the encoder. + enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`): + A list of dilations for each layer in the encoder. + enc_attention_channels (`int`, *optional*, defaults to 64): + The number of attention channels in the SqueezeExcitationBlock. + enc_res2net_scale (`int`, *optional*, defaults to 2): + The scale of the Res2Net block in the encoder. + enc_se_channels (`int`, *optional*, defaults to 64): + The number of output channels after squeeze in the SqueezeExcitationBlock. + """ + + model_type = "qwen3_tts_tokenizer_v1_decoder_dit" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=22, + num_attention_heads=16, + ff_mult=2, + emb_dim=512, + head_dim=64, + rope_theta=10000.0, + max_position_embeddings=32768, + block_size=24, + look_ahead_layers=[10], + look_backward_layers=[0, 20], + repeats=2, + num_embeds=8193, + mel_dim=80, + dropout=0.1, + enc_emb_dim=192, + enc_dim=128, + enc_channels=[256, 256, 256, 256, 768], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=64, + enc_res2net_scale=2, + enc_se_channels=64, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.ff_mult = ff_mult + self.emb_dim = emb_dim + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.look_ahead_layers = look_ahead_layers + self.look_backward_layers = look_backward_layers + self.repeats = repeats + self.num_embeds = num_embeds + self.mel_dim = mel_dim + self.dropout = dropout + self.enc_emb_dim = enc_emb_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + super().__init__(**kwargs) + + +class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module. + It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms. + + Args: + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 1536): + The number of channels in the initial upsampling layer. + resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`): + A list of kernel sizes for each residual block. + resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A list of dilation sizes for each residual block. + upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`): + A list of upsampling rates for each upsampling layer. 
+ upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`): + A list of kernel sizes for each upsampling layer. + """ + + model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan" + + def __init__( + self, + mel_dim=80, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + **kwargs, + ): + self.mel_dim = mel_dim + self.upsample_initial_channel = upsample_initial_channel + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + super().__init__(**kwargs) + + +class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1DecoderConfig`]. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + dit_config ([`DiT_Args`], *optional*): + Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms. + bigvgan_config ([`BigVGAN_Args`], *optional*): + Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms. + """ + + model_type = "qwen3_tts_tokenizer_v1_decoder" + sub_configs = { + "dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig, + "bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig, + } + + def __init__(self, dit_config=None, bigvgan_config=None, **kwargs): + if dit_config is None: + dit_config = {} + if bigvgan_config is None: + bigvgan_config = {} + self.dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_config) + self.bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_config) + super().__init__(**kwargs) + + +class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 Encoder. + + The encoder typically takes mel-spectrogram features and produces high-level + audio representations, then (optionally) applies an Audio-VQ module (e.g., GRVQ) + to discretize continuous representations into codes. + + Args: + n_mels (`int`, *optional*, defaults to 128): + Number of mel bins in the input mel-spectrogram. + n_ctx (`int`, *optional*, defaults to 1500): + Maximum input sequence length (in frames/tokens) for the encoder. + n_state (`int`, *optional*, defaults to 1280): + Hidden size (model dimension) of the encoder transformer. + n_head (`int`, *optional*, defaults to 20): + Number of attention heads in each transformer layer. + n_layer (`int`, *optional*, defaults to 32): + Number of transformer layers. + n_window (`int`, *optional*, defaults to 100): + Window size used by the model for local attention / chunking (implementation-dependent). + output_dim (`int`, *optional*, defaults to 3584): + Output feature dimension produced by the encoder head (before/after projection, implementation-dependent). + + grad_checkpointing (`bool`, *optional*, defaults to `False`): + Whether to enable gradient checkpointing to reduce memory usage during training. + enable_mp (`bool`, *optional*, defaults to `False`): + Whether to enable model parallel features (implementation-dependent). 
+ audio_sequence_parallel (`bool`, *optional*, defaults to `False`): + Whether to enable sequence parallelism for audio branch (implementation-dependent). + + audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`): + Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc. + audio_vq_layers (`int`, *optional*, defaults to 6): + Number of VQ layers / quantizers (e.g., number of residual quantizers for RVQ/GRVQ-like designs). + audio_vq_codebook_size (`int`, *optional*, defaults to 32768): + Size of each codebook (number of entries). + audio_vq_codebook_dim (`int`, *optional*, defaults to 1280): + Dimension of codebook vectors (often equals encoder hidden size). + audio_vq_pe (`bool`, *optional*, defaults to `True`): + Whether to use positional encoding (or position embeddings) inside the VQ module. + audio_vq_ds_rate (`int`, *optional*, defaults to 2): + Downsampling rate applied before VQ (e.g., temporal downsample factor). + """ + + model_type = "qwen3_tts_tokenizer_v1_encoder" + + def __init__( + self, + n_mels=128, + n_ctx=1500, + n_state=1280, + n_head=20, + n_layer=32, + n_window=100, + output_dim=3584, + grad_checkpointing=False, + enable_mp=False, + audio_sequence_parallel=False, + audio_vq_type="GRVQ", + audio_vq_layers=6, + audio_vq_codebook_size=32768, + audio_vq_codebook_dim=1280, + audio_vq_pe=True, + audio_vq_ds_rate=2, + **kwargs, + ): + super().__init__(**kwargs) + self.n_mels = n_mels + self.n_ctx = n_ctx + self.n_state = n_state + self.n_head = n_head + self.n_layer = n_layer + self.n_window = n_window + self.output_dim = output_dim + self.grad_checkpointing = grad_checkpointing + self.enable_mp = enable_mp + self.audio_sequence_parallel = audio_sequence_parallel + self.audio_vq_type = audio_vq_type + self.audio_vq_layers = audio_vq_layers + self.audio_vq_codebook_size = audio_vq_codebook_size + self.audio_vq_codebook_dim = audio_vq_codebook_dim + self.audio_vq_pe = audio_vq_pe + self.audio_vq_ds_rate = audio_vq_ds_rate + + +class Qwen3TTSTokenizerV1Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1Config`]. + It is used to instantiate a Qwen3TTSTokenizerV1Model model according to the specified + sub-models configurations, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model. + decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model. + """ + + model_type = "qwen3_tts_tokenizer_25hz" + sub_configs = { + "encoder_config": Qwen3TTSTokenizerV1EncoderConfig, + "decoder_config": Qwen3TTSTokenizerV1DecoderConfig, + } + + def __init__( + self, + encoder_config=None, + decoder_config=None, + input_sample_rate=24000, + output_sample_rate=24000, + decode_upsample_rate=1920, + encode_downsample_rate=1920, + **kwargs, + ): + super().__init__(**kwargs) + if encoder_config is None: + encoder_config = {} + logger.info("encoder_config is None. Initializing encoder with default values") + if decoder_config is None: + decoder_config = {} + logger.info("decoder_config is None. 
Initializing decoder with default values") + + self.encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config) + self.decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config) + + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.decode_upsample_rate = decode_upsample_rate + self.encode_downsample_rate = encode_downsample_rate + + +__all__ = [ + "Qwen3TTSTokenizerV1Config", + "Qwen3TTSTokenizerV1EncoderConfig", + "Qwen3TTSTokenizerV1DecoderConfig", + "Qwen3TTSTokenizerV1DecoderBigVGANConfig", + "Qwen3TTSTokenizerV1DecoderDiTConfig", +] diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bceafc98e39d44e01f218da3007ffa4d7661e473 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py @@ -0,0 +1,1525 @@ +# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3TTSTokenizerV1 model.""" + +import math +from dataclasses import dataclass + +import numpy as np +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.nn.utils.rnn import pad_sequence +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.utils import ModelOutput, auto_docstring, logging +from transformers.utils.hub import cached_file + +from .configuration_qwen3_tts_tokenizer_v1 import ( + Qwen3TTSTokenizerV1Config, + Qwen3TTSTokenizerV1DecoderBigVGANConfig, + Qwen3TTSTokenizerV1DecoderConfig, + Qwen3TTSTokenizerV1DecoderDiTConfig, + Qwen3TTSTokenizerV1EncoderConfig, +) +from .vq.speech_vq import WhisperEncoderVQ, XVectorExtractor +from .vq.whisper_encoder import get_mel_audio, get_T_after_cnn + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring +class Qwen3TTSTokenizerV1EncoderOutput(ModelOutput): + r""" + audio_codes (`List[torch.LongTensor]`): + Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i,). + xvectors (`List[torch.FloatTensor]`): + X-vector embeddings computed using `model.encode`, each tensor has shape (xvector_dim,). + ref_mels (`List[torch.FloatTensor]`): + Reference mel spectrogram computed using `model.encode`, each tensor has shape (mel_length_i, mel_dim,). + """ + + audio_codes: list[torch.LongTensor] = None + xvectors: list[torch.FloatTensor] = None + ref_mels: list[torch.FloatTensor] = None + + +@dataclass +@auto_docstring +class Qwen3TTSTokenizerV1DecoderOutput(ModelOutput): + r""" + audio_values (`List[torch.FloatTensor]`): + Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1. + Each tensor has shape (segment_length_i). 
+ """ + + audio_values: list[torch.FloatTensor] = None + + +@auto_docstring +class Qwen3TTSTokenizerV1DecoderPreTrainedModel(PreTrainedModel): + config: Qwen3TTSTokenizerV1DecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +@auto_docstring +class Qwen3TTSTokenizerV1EncoderPreTrainedModel(PreTrainedModel): + config: Qwen3TTSTokenizerV1EncoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +class Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = 
hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. 
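+ The whole block is wrapped in a residual connection.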
+ """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). + """ + + def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + 
        hidden_states: torch.Tensor,
+        speaker_embedding: torch.Tensor,
+        condition_vector: torch.Tensor,
+        code_embed: torch.Tensor,
+        drop_audio_cond: bool | None = False,
+        code_embed_uncond: torch.Tensor | None = None,
+        apply_cfg: bool | None = True,
+    ):
+        if apply_cfg:
+            hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
+            speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
+            condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
+            code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
+        elif drop_audio_cond:  # classifier-free guidance: drop the conditioning audio
+            condition_vector = torch.zeros_like(condition_vector)
+            speaker_embedding = torch.zeros_like(speaker_embedding)
+        condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
+        hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))
+
+        return hidden_states
+
+
+# Transformer backbone using DiT blocks
+class DiTCodecEmbedding(nn.Module):
+    def __init__(self, codec_num_embeds, codec_dim, repeats):
+        super().__init__()
+        self.repeats = repeats
+        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
+
+    def forward(self, code, drop_code=False):
+        if drop_code:
+            code = torch.zeros_like(code)
+        code_embed = self.codec_embed(code)
+
+        code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
+        return code_embed
+
+
+# AdaLayerNormZero
+# Returns the modulated x for the attention input, plus parameters for the later MLP modulation.
+class AdaLayerNormZero(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 6)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, hidden_states, emb=None):
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+        hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for the final layer
+# Returns only the modulated x for the attention input, since there is no further MLP modulation.
+class AdaLayerNormZero_Final(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 2)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, hidden_states, emb):
+        emb = self.linear(self.silu(emb))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+
+        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return hidden_states
+
+
+# FeedForward
+class DiTMLP(nn.Module):
+    def __init__(self, dim, mult=4, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+
+        self.ff = nn.ModuleList(
+            [
+                nn.Linear(dim, inner_dim),
+                nn.GELU(approximate="tanh"),
+                nn.Dropout(dropout),
+                nn.Linear(inner_dim, dim),
+            ]
+        )
+
+    def forward(self, hidden_states):
+        for layer in self.ff:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+# Modified from Llama with a different rotate function; will be fixed in the next release
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = AdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper + by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def 
forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size) -> torch.Tensor: + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate") + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def 
forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise TypeError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + + def forward(self, x): + return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias) + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + causal_type="1", + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + ), + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + ), + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + ), + ] + ) + + if causal_type == "1": + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + else: + self.convs2 = nn.ModuleList( + [ + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + ), + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + ), + CausalConv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + if causal_type == "2": + self.pre_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=self._get_padding(kernel_size, 1), + ) + self.pre_act = TorchActivation1d(activation=SnakeBeta(channels)) + else: + self.pre_conv = nn.Identity() + self.pre_act = nn.Identity() + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, x): + hidden_states = self.pre_conv(x) + hidden_states = self.pre_act(hidden_states) + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + x = x + hidden_states + return x + + +@auto_docstring +class 
Qwen3TTSTokenizerV1DecoderBigVGANModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel): + config: Qwen3TTSTokenizerV1DecoderBigVGANConfig + + def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 5, 1, padding=2) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock( + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + dilation, + "1" if layer_idx > 1 else "2", + ) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze(1) + + +@auto_docstring +class Qwen3TTSTokenizerV1DecoderDiTModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel): + config: Qwen3TTSTokenizerV1DecoderDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = 
Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] * 2 + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + def optimized_scale(self, positive_flat, negative_flat): + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + return st_star + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn( + [quantized_code.shape[0], 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype + ) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + 
condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + values = initial_state.clone() + for t0, t1 in zip(time_embedding[:-1], time_embedding[1:]): + dt = t1 - t0 + vt = ode_function(t0, values) + values = values + vt * dt + + generated_mel_spectrogram = values.permute(0, 2, 1) + return generated_mel_spectrogram + + +@auto_docstring +class Qwen3TTSTokenizerV1Decoder(Qwen3TTSTokenizerV1DecoderPreTrainedModel): + config: Qwen3TTSTokenizerV1DecoderConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen3TTSTokenizerV1DecoderDiTModel", "Qwen3TTSTokenizerV1DecoderBigVGANModel"] + + def __init__(self, config: Qwen3TTSTokenizerV1DecoderConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen3TTSTokenizerV1Decoder must inference with fp32, " + "but flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen3TTSTokenizerV1Decoder " + "will fallback to sdpa." 
+ ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen3TTSTokenizerV1Decoder does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.dit = Qwen3TTSTokenizerV1DecoderDiTModel._from_config(config.dit_config, attn_implementation=attn_impl) + self.bigvgan = Qwen3TTSTokenizerV1DecoderBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.dit.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + ) + + waveform = self.bigvgan(mel_spectrogram) + + return waveform + + +class Qwen3TTSTokenizerV1Encoder(Qwen3TTSTokenizerV1EncoderPreTrainedModel): + config: Qwen3TTSTokenizerV1EncoderConfig + + def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig): + super().__init__(config) + + self.tokenizer = WhisperEncoderVQ( + n_mels=config.n_mels, + n_ctx=config.n_ctx, + n_state=config.n_state, + n_head=config.n_head, + n_layer=config.n_layer, + n_window=config.n_window, + output_dim=config.output_dim, + grad_checkpointing=config.grad_checkpointing, + enable_mp=config.enable_mp, + audio_sequence_parallel=config.audio_sequence_parallel, + audio_vq_type=config.audio_vq_type, + audio_vq_layers=config.audio_vq_layers, + audio_vq_codebook_size=config.audio_vq_codebook_size, + audio_vq_codebook_dim=config.audio_vq_codebook_dim, + audio_vq_pe=config.audio_vq_pe, + audio_vq_ds_rate=config.audio_vq_ds_rate, + ) + + self.padding = True + self.audio_vq_ds_rate = self.tokenizer.audio_vq_ds_rate + + def speech2mel(self, speeches): + mels = [ + get_mel_audio(speech, padding=self.padding, audio_vq_ds_rate=self.audio_vq_ds_rate) + .to(speech.dtype) + .to(self.tokenizer.conv1.weight.device) + for speech in speeches + ] + return mels + + def mel2code(self, mels): + audio_mellens = [mel.size(-1) for mel in mels] + audio_aftercnnlens = [get_T_after_cnn(T) for T in audio_mellens] + audio_seqlens = [T + 2 for T in audio_aftercnnlens] + + with torch.no_grad(): + _, indices = self.tokenizer( + x_list=mels, + audio_mellens=audio_mellens, + audio_aftercnnlens=audio_aftercnnlens, + audio_seqlens=audio_seqlens, + return_indices=True, + ) + + indice_lens = [T // self.tokenizer.audio_vq_ds_rate for T in audio_aftercnnlens] + indices = pad_sequence(torch.split(indices, indice_lens), batch_first=True, padding_value=0) + + return indices, indice_lens + + def quantize_speech(self, speeches): + mels = self.speech2mel(speeches) + indices, indice_lens = self.mel2code(mels) + return indices, indice_lens + + +@auto_docstring +class Qwen3TTSTokenizerV1PreTrainedModel(PreTrainedModel): + config: Qwen3TTSTokenizerV1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _can_compile_fullgraph = False + _supports_attention_backend = True + + +@auto_docstring( + custom_intro=""" + The Qwen3TTSTokenizerV1 model. 
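+
+    Its encoder turns a waveform into discrete codes together with a speaker x-vector and a
+    reference mel spectrogram, and its decoder (a DiT sampler followed by a BigVGAN vocoder)
+    reconstructs the waveform from them.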
+ """ +) +class Qwen3TTSTokenizerV1Model(Qwen3TTSTokenizerV1PreTrainedModel): + def __init__(self, config: Qwen3TTSTokenizerV1Config): + super().__init__(config) + self.config = config + + self.input_sample_rate = config.input_sample_rate + self.output_sample_rate = config.output_sample_rate + + self.decode_upsample_rate = config.decode_upsample_rate + self.encode_downsample_rate = config.encode_downsample_rate + + self.encoder = Qwen3TTSTokenizerV1Encoder._from_config(self.config.encoder_config) + self.decoder = Qwen3TTSTokenizerV1Decoder._from_config(self.config.decoder_config) + + self.encoder_xvector_extractor = None + + self.post_init() + + def load_encoder_xvector_extractor(self, model_path): + self.encoder_xvector_extractor = XVectorExtractor(model_path) + + def get_model_type(self): + return self.config.model_type + + def get_input_sample_rate(self): + return self.input_sample_rate + + def get_output_sample_rate(self): + return self.output_sample_rate + + def get_encode_downsample_rate(self): + return self.encode_downsample_rate + + def get_decode_upsample_rate(self): + return self.decode_upsample_rate + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + config=None, + cache_dir=None, + ignore_mismatched_sizes=False, + force_download=False, + local_files_only=False, + token=None, + revision="main", + use_safetensors=None, + weights_only=True, + **kwargs, + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + encoder_xvector_extractor_path = cached_file( + pretrained_model_name_or_path, + "campplus.onnx", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + if encoder_xvector_extractor_path is None: + raise ValueError(f"""{pretrained_model_name_or_path}/{encoder_xvector_extractor_path} not exists""") + model.load_encoder_xvector_extractor(encoder_xvector_extractor_path) + + return model + + def encode( + self, + input_values: torch.Tensor, + padding_mask: torch.Tensor | None = None, + return_dict: bool | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | Qwen3TTSTokenizerV1EncoderOutput: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Float values of the input audio waveform. + padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Indicates which inputs are to be ignored due to padding, + where elements are either 1 for *not masked* or 0 for *masked*. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
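+
+        Example (illustrative sketch; shapes are assumptions and the x-vector extractor must
+        already be loaded, e.g. via `from_pretrained`):
+
+        >>> wav = torch.randn(1, 16000)                    # ~1 s of 16 kHz mono audio
+        >>> mask = torch.ones_like(wav, dtype=torch.long)  # nothing is padding
+        >>> codes, xvectors, ref_mels = model.encode(wav, padding_mask=mask, return_dict=False)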
+ """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + wavs = [value[: mask.sum()] for value, mask in zip(input_values, padding_mask)] + + codes, codes_lens = self.encoder.quantize_speech(wavs) + codes = [c[:length] for c, length in zip(codes, codes_lens)] + + xvectors = [] + ref_mels = [] + for wav in wavs: + xvector, ref_mel = self.encoder_xvector_extractor.extract_code(wav.cpu().numpy()) + xvector = torch.tensor(xvector).to(wav.dtype).to(wav.device) + ref_mel = torch.tensor(ref_mel).to(wav.dtype).to(wav.device) + xvectors.append(xvector) + ref_mels.append(ref_mel) + + if not return_dict: + return (codes, xvectors, ref_mels) + + return Qwen3TTSTokenizerV1EncoderOutput(codes, xvectors, ref_mels) + + def decode( + self, + audio_codes: torch.Tensor, + xvectors: torch.Tensor, + ref_mels: torch.Tensor, + return_dict: bool | None = None, + ) -> tuple[torch.Tensor, torch.Tensor] | Qwen3TTSTokenizerV1DecoderOutput: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be + trimmed. + + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length)`, *optional*): + Discrete code embeddings computed using `model.encode`. + xvectors (`torch.FloatTensor` of shape `(batch_size, xvector_dim)`, *optional*): + X-vector embeddings computed using `model.encode`. + ref_mels (`torch.FloatTensor` of shape `(batch_size, mel_length, mel_dim)`, *optional*): + Reference mel spectrogram computed using `model.encode`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + audio_values = self.decoder(code=audio_codes, reference_mel=ref_mels, conditioning=xvectors) + + audio_lengths = (audio_codes > 0).sum(1) * self.decode_upsample_rate + audio_values = [a[:length] for a, length in zip(audio_values, audio_lengths)] + + if not return_dict: + return (audio_values,) + + return Qwen3TTSTokenizerV1DecoderOutput(audio_values) + + +__all__ = ["Qwen3TTSTokenizerV1Model", "Qwen3TTSTokenizerV1PreTrainedModel"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/__init__.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f4e7196412f2596686e3cc0b83040ac03890b0 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/__init__.py @@ -0,0 +1 @@ +# Qwen3 TTS 25Hz vector-quantization package. diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz new file mode 100644 index 0000000000000000000000000000000000000000..28ea26909dbdfd608aef67afc4d74d7961ae4bb6 Binary files /dev/null and b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz differ diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/core_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..9c103a851e553d5596c4252a6af512595c29c1f5 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/core_vq.py @@ -0,0 +1,515 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import random +import typing as tp +from math import ceil + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn + + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +@torch.no_grad() +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + dists = -( + samples.pow(2).sum(1, keepdim=True) + - 2 * torch.matmul(samples, means.t()) + + means.t().pow(2).sum(0, keepdim=True) + ) + + buckets = dists.max(dim=-1).indices + del dists + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + return means, bins + + +def preprocess(x): + x = rearrange(x, "... d -> (...) d") + return x + + +def postprocess_emb(embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. 
+ codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: float = 2.0, + ): + super().__init__() + self.decay = decay + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.inited = None + self.cluster_size = None + self.embed = None + self.embed_avg = None + self.training = True + + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited]) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size + expired_codes = cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + else: + print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}") + + batch_samples = rearrange(batch_samples, "... d -> (...) 
d") + self.replace_(batch_samples, mask=expired_codes) + # sync buffers outside for efficiency + # distrib.broadcast_tensors(self.buffers()) + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x, buffers): + self.inited, self.cluster_size, self.embed, self.embed_avg = buffers + + shape = x.shape + # pre-process + x = preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind, buffers): + self.inited, self.cluster_size, self.embed, self.embed_avg = buffers + + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x, buffers): + self.inited, self.cluster_size, self.embed, self.embed_avg = buffers + + shape, dtype = x.shape, x.dtype + x = preprocess(x) + + self.init_embed_(x) + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + # Note: after ema update, there is a very small difference between codebooks on GPUs. + # The impact can be very small, ignore it. + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently, supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. 
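+
+    Note:
+        `forward`, `encode` and `decode` also expect a `buffers` tuple
+        `(inited, cluster_size, embed, embed_avg)`; these buffers are registered on the parent
+        residual quantizer and passed in per layer rather than stored on this module.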
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: int | None = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: float = 2.0, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + self.training = True + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x, buffers): + # x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x, buffers) + return embed_in + + def decode(self, embed_ind, buffers): + quantize = self._codebook.decode(embed_ind, buffers) + quantize = self.project_out(quantize) + # quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x, buffers): + device = x.device + # x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x, buffers) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + # quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class DistributedResidualVectorQuantization(nn.Module): + """Efficient distributed residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, quantize_dropout: bool = False, rand_num_quant: list | None = None, **kwargs): + super().__init__() + """ + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + """ + codebook_size, codebook_dim = ( + kwargs["codebook_size"], + kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"], + ) + kmeans_init = kwargs["kmeans_init"] + if isinstance(kmeans_init, bool): + if not kwargs["kmeans_init"]: + # use uniform init + embed = uniform_init(num_quantizers, codebook_size, codebook_dim) + inited = True + else: + # to perform kmeans init on first batch + embed = torch.zeros(num_quantizers, codebook_size, codebook_dim) + inited = False + elif isinstance(kmeans_init, str): + # use prepared kmeans init + embed = np.load(kmeans_init) + embed = torch.from_numpy(embed) + if embed.dim() == 2: + embed = embed.unsqueeze(0) + inited = True + else: + raise TypeError("kmeans_init should be either a bool or string path to init weights.") + + self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)])) + self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + self.q0_ds_ratio = 1 + if "q0_ds_ratio" in kwargs: + self.q0_ds_ratio = kwargs.pop("q0_ds_ratio") + + self.layers = nn.ModuleList() + for i in range(num_quantizers): + vq_args = dict(**kwargs) + vq = VectorQuantization(**vq_args) + self.layers.append(vq) + + self.quantize_dropout = quantize_dropout + self.rand_num_quant = rand_num_quant + + def forward(self, x, n_q: int | None = None): + quantized_out = torch.zeros_like(x) + residual = x + bb, cc, tt = x.shape + device = x.device + + all_losses = [] + all_indices = [] + all_sub_quants = [] + n_q = n_q or len(self.layers) + + should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None + if should_quantize_dropout: + rand_quantize_dropout_index = random.choice(self.rand_num_quant) + + null_indices_shape = (x.shape[0], x.shape[2]) + null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) + null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) + null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype) + + for quantizer_index, layer in enumerate(self.layers[:n_q]): + # dropout except the first quantizer + if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index: + all_indices.append(null_indices) + all_losses.append(null_loss) + all_sub_quants.append(null_sub_quant) + continue + + quant_in = residual + if self.q0_ds_ratio > 1 and quantizer_index == 0: + quant_in = F.interpolate(quant_in, size=[tt // 2]) + quantized, indices, loss = layer( + quant_in, + [ + self.inited[quantizer_index], + self.cluster_size[quantizer_index], + self.embed[quantizer_index], + self.embed_avg[quantizer_index], + ], + ) + if self.q0_ds_ratio > 1 and quantizer_index == 0: + quantized = F.interpolate(quantized, size=[tt]) + indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long() + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + all_sub_quants.append(quantized) + + # sync buffers after one forward step + # distrib.broadcast_tensors(self.buffers()) + out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants)) 
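+        # torch.stack adds a leading per-quantizer dimension, so callers receive one entry per
+        # quantizer (including the null placeholders appended for dropped-out quantizers during training).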
+ + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: int | None = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for i, layer in enumerate(self.layers[:n_q]): + indices = layer.encode(residual, [self.inited[i], self.cluster_size[i], self.embed[i], self.embed_avg[i]]) + quantized = layer.decode(indices, [self.inited[i], self.cluster_size[i], self.embed[i], self.embed_avg[i]]) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices, [self.inited[i], self.cluster_size[i], self.embed[i], self.embed_avg[i]]) + quantized_out = quantized_out + quantized + return quantized_out + + +class DistributedGroupResidualVectorQuantization(nn.Module): + """Efficient distributed group residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/abs/2305.02765 + Group Then rvq + """ + + def __init__( + self, + *, + num_groups, + num_quantizers, + quantize_dropout: bool = False, + rand_num_quant: list | None = None, + **kwargs, + ): + super().__init__() + self.rvqs = nn.ModuleList( + [ + DistributedResidualVectorQuantization( + num_quantizers=num_quantizers, + quantize_dropout=quantize_dropout, + rand_num_quant=rand_num_quant, + **kwargs, + ) + for _ in range(num_groups) + ] + ) + self.num_groups = num_groups + + def forward(self, x, n_q: int | None = None): + x_lst = torch.chunk(x, chunks=self.num_groups, dim=1) + all_quantized_out = [] + all_indices = [] + all_losses = [] + for mod, item in zip(self.rvqs, x_lst): + quantized_out, out_indices, out_losses = mod(item, n_q) + all_quantized_out.append(quantized_out) + all_indices.append(out_indices) + all_losses.append(out_losses) + + out_losses = torch.stack(all_losses, dim=1).mean(dim=1) + + return torch.cat(all_quantized_out, dim=1), torch.stack(all_indices, dim=1), out_losses + + def encode(self, x: torch.Tensor, n_q: int | None = None) -> torch.Tensor: + x_lst = torch.chunk(x, chunks=self.num_groups, dim=1) + return torch.stack([mod.encode(item, n_q) for mod, item in zip(self.rvqs, x_lst)], dim=1) + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + q_indices_lst = torch.chunk(q_indices, chunks=self.num_groups, dim=1) + return torch.cat([mod.decode(item.squeeze(1)) for mod, item in zip(self.rvqs, q_indices_lst)], dim=1) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..805feb81fea17ae8d692654a4882e3c13d54a29d --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py @@ -0,0 +1,403 @@ +# Copyright 2026 The Alibaba Qwen team. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import operator +from itertools import accumulate + +import onnxruntime +import sox +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.compliance.kaldi as kaldi +from librosa.filters import mel as librosa_mel_fn +from torch import Tensor + +from .core_vq import DistributedGroupResidualVectorQuantization +from .whisper_encoder import Conv1d, ConvTranspose1d, WhisperEncoder + + +def dynamic_range_compression_torch(x, c=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * c) + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +class MelSpectrogramFeatures(nn.Module): + """ + Calculate the BigVGAN style mel spectrogram of an input signal. + Args: + filter_length (int): The number of samples in the filter window, + used for the Fourier Transform. Default is 1024. + hop_length (int): The number of samples between successive frames + (stride of the STFT). Default is 160. + win_length (int): The length of the window function applied to each frame, + usually less than or equal to the filter length. Default is 640. + n_mel_channels (int): The number of Mel-frequency channels to output + from the Mel-scale spectrogram. Default is 80. + mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. + Default is 0. + mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. + Default is 8000. + sampling_rate (int): The sampling rate of the audio data (in Hz). + Default is 16000. + sampling_rate_org (int, optional): The original sampling rate of the audio + data before any resampling (in Hz), if applicable. Default is None. + padding (str): The padding mode for the input signal. 'center' pads the signal + symmetrically around its center. Default is 'center'. 
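+        use_db (bool): Accepted for configuration compatibility; not referenced by this
+            implementation. Default is False.
+
+    Example (illustrative; the one-second input length is an arbitrary assumption):
+        >>> mel_extractor = MelSpectrogramFeatures()
+        >>> mel = mel_extractor(torch.randn(1, 16000))  # (1, 80, 100) log-mel frames for 1 s of 16 kHz audio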
+ """ + + def __init__( + self, + filter_length=1024, + hop_length=160, + win_length=640, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=16000, + sampling_rate_org=None, + padding="center", + use_db=False, + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.sampling_rate = sampling_rate + self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate + self.mel_basis = {} + self.hann_window = {} + + def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + with torch.no_grad(): + feats = self.extract(audio, **kwargs) + return feats + + def extract(self, audio, **kwargs): + if len(audio.shape) == 3: + audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2) + assert len(audio.shape) == 2 + + y = audio + if len(list(self.mel_basis.keys())) == 0: + mel = librosa_mel_fn( + sr=self.sampling_rate, + n_fft=self.filter_length, + n_mels=self.n_mel_channels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + ) + self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((self.filter_length - self.hop_length) / 2), int((self.filter_length - self.hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.hann_window[str(y.device)], + center=False, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +class XVectorExtractor(nn.Module): + def __init__(self, audio_codec_with_xvector): + super().__init__() + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ["CPUExecutionProvider"] + self.ort_session = onnxruntime.InferenceSession( + audio_codec_with_xvector, sess_options=option, providers=providers + ) + + self.tfm = sox.Transformer() + self.tfm.norm(db_level=-6) + + self.mel_ext = MelSpectrogramFeatures( + filter_length=1024, + hop_length=160, + win_length=640, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=16000, + ) + + def extract_code(self, audio): + with torch.no_grad(): + norm_audio = self.sox_norm(audio) + + norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0) + feat = kaldi.fbank(norm_audio, num_mel_bins=80, dither=0, sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + norm_embedding = self.ort_session.run( + None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten() + norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0) + + ref_mel = self.mel_ext.extract(audio=norm_audio) + + return norm_embedding.numpy(), ref_mel.permute(0, 2, 1).squeeze(0).numpy() + + def sox_norm(self, audio): + wav_norm = self.tfm.build_array(input_array=audio, 
sample_rate_in=16000) + return wav_norm + + +class WhisperEncoderVQ(WhisperEncoder): + def __init__( + self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + n_window: int = 1500, + output_dim: int = 512, + grad_checkpointing: bool = False, + enable_mp: bool = False, + audio_sequence_parallel: bool = False, + audio_vq_layers: int = -1, + audio_vq_type: str = "NULL", + audio_vq_codebook_size: int = 4096, + audio_vq_pe: bool = False, + audio_vq_commit_loss: float = 0.0, + audio_vq_out_commit_loss: float = 0.0, + audio_vq_no_quantize: bool = False, + audio_vq_ff_layer: int = 0, + audio_vq_threshold_ema_dead_code: float = 0.1, + audio_vq_codebook_dim: int = None, + audio_vq_ds_rate: int = None, + ): + super().__init__( + n_mels, + n_ctx, + n_state, + n_head, + n_layer, + n_window, + output_dim, + grad_checkpointing, + enable_mp, + audio_sequence_parallel, + ) + + self.audio_vq_layers = audio_vq_layers + self.audio_vq_type = audio_vq_type + self.audio_vq_codebook_size = audio_vq_codebook_size + self.audio_vq_pe = audio_vq_pe + self.audio_vq_commit_loss = audio_vq_commit_loss + self.audio_vq_out_commit_loss = audio_vq_out_commit_loss + self.audio_vq_no_quantize = audio_vq_no_quantize + self.audio_vq_ff_layer = audio_vq_ff_layer + + if audio_vq_layers > 0: + self.vq_feature_dim = self.n_state + self.audio_vq_ds_rate = 1 + else: + raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}") + + if self.audio_vq_ds_rate == audio_vq_ds_rate: + self.audio_vq_downsample = nn.Identity() + self.audio_vq_upsample = nn.Identity() + else: + assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0 + stride = audio_vq_ds_rate // self.audio_vq_ds_rate + self.audio_vq_downsample = Conv1d( + self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride + ) + self.audio_vq_upsample = ConvTranspose1d( + self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride + ) + self.audio_vq_ds_rate = audio_vq_ds_rate + + if audio_vq_type == "GRVQ": + self.audio_quantizer = DistributedGroupResidualVectorQuantization( + codebook_size=audio_vq_codebook_size, + dim=self.vq_feature_dim, + codebook_dim=self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim, + num_groups=1, + num_quantizers=1, + kmeans_init=False, + threshold_ema_dead_code=audio_vq_threshold_ema_dead_code, + ) + else: + raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}") + + if self.audio_vq_pe: + self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state) + + def _calc_quantize_activities(self, indices): + indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0) + vq_num_activities = sum(indices_onehot > 0) + vq_num_tokens = sum(indices_onehot) + return { + "vq_num_activities": vq_num_activities, + "vq_num_tokens": vq_num_tokens, + } + + def _do_quantize(self, x, pe=None, y=None): + """ + x: torch.Tensor, shape = (T, D) + q: torch.Tensor, shape = (T, D) + i: torch.Tensor, shape = (T) + """ + if self.audio_vq_out_commit_loss > 0: + x_teacher = x.clone() + x = x.unsqueeze(0) + + x = self.audio_vq_downsample(x.transpose(1, 2)) + x = x.transpose(1, 2) + + vq_stats = {} + + if self.audio_vq_type == "GRVQ": + if self.training: + raise NotImplementedError + else: + indices = self.audio_quantizer.encode(x) + x = self.audio_quantizer.decode(indices) + indices = indices.squeeze(2).squeeze(1) + + vq_stats.update(self._calc_quantize_activities(indices)) + + x, indices = x.squeeze(0), indices.squeeze(0) + if 
self.audio_vq_pe: + x = x + pe + x = self.project_after_vq_pe(x) + + x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2)) + x = x.transpose(1, 2).squeeze(0) + + if self.audio_vq_out_commit_loss > 0: + vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x) + vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss + + return x, indices, vq_stats + + def forward( + self, + x_list: list[Tensor], + audio_mellens: list[int], + audio_aftercnnlens: list[int], + audio_seqlens: list[int], + return_indices=False, + audio_pitches=None, + ): + """ + x : torch.Tensor, shape = (n_mels, n_ctx) + the mel spectrogram of the audio + """ + + aftercnn_x_list = [] + pe_for_vq_list = [] + for each_x in x_list: + each_x_split_list = each_x.split(self.n_window * 2, dim=1) + for each_x_split in each_x_split_list: + each_x_split = F.gelu(self.conv1(each_x_split)) + each_x_split = F.gelu(self.conv2(each_x_split)) + each_x_split = each_x_split.permute(1, 0) # L,D + + each_positional_embedding_split = self.positional_embedding[: each_x_split.shape[0]] + aftercnn_x_list.append(each_x_split + each_positional_embedding_split.to(each_x_split.dtype)) + + pe_for_vq_split = self.positional_embedding[: each_x_split.shape[0] // self.audio_vq_ds_rate] + pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype)) + + pe_for_vq = torch.cat(pe_for_vq_list, dim=0) + x = torch.cat(aftercnn_x_list, dim=0) + + output_list = [] + for item in audio_aftercnnlens: + while item > self.n_window: + output_list.append(self.n_window) + item -= self.n_window + output_list.append(item) + + cu_seqlens = list(accumulate(output_list, func=operator.add, initial=0)) + cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32) + + layer_id = 0 + + for block in self.blocks: + layer_id += 1 + + x = block(x, cu_seqlens=cu_seqlens) + + if self.audio_vq_layers == layer_id: # vq inside encoder + x, indices, vq_stats = self._do_quantize(x, pe_for_vq) + if return_indices: + return x, indices + + if self.avg_pooler: + x_list = x.split(audio_aftercnnlens, dim=0) + token_x_list = [] + for x in x_list: + x = x.permute(1, 0) + x = self.avg_pooler(x) + x = x.permute(1, 0) + token_x_list.append(x) + x = torch.cat(token_x_list, dim=0) + + x = self.ln_post(x) + + x = self.proj(x) + + output = torch.zeros((x.size(0) + len(audio_seqlens) * 2, x.size(1)), device=x.device, dtype=x.dtype) + + audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0)) + start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32) + end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1 + + audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool) + audio_tokens_mask[start_ids] = False + audio_tokens_mask[end_ids] = False + output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype) + output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype) + output[audio_tokens_mask] = x + + if self.audio_vq_type != "NULL": + return output, vq_stats + return output diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c3c987a0fac5b7b29cb4a1aacba4f2f1ff9a7d --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py @@ -0,0 +1,385 @@ +# Copyright 2026 The Alibaba Qwen team. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import operator +import os +from functools import cache +from itertools import accumulate + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func + except ImportError: + print( + "\n********\nWarning: flash-attn is not installed. " + "Will only run the manual PyTorch version. " + "Please install flash-attn for faster inference.\n********\n " + ) + flash_attn_varlen_func = None + + +N_FFT = 400 +HOP_LENGTH = 160 + + +@cache +def mel_filters(device, n_mels: int) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), + ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + with np.load(filters_path, allow_pickle=False) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: str | np.ndarray | torch.Tensor, + n_mels: int = 80, + padding: int = 0, + device: str | torch.device | None = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + +def get_T_after_cnn(l_in, dilation=1): + for padding, kernel_size, stride in eval("[(1,3,1)] + [(1,3,2)] "): + l_out = l_in + 2 * padding - dilation * (kernel_size - 1) - 1 + l_out = 1 + l_out // stride + l_in = l_out + return l_out + + +def 
get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128): + audio_len = len(audio) + if padding: + reduction = 160 * 2 * audio_vq_ds_rate + audio_pad = math.ceil(audio_len / reduction) * reduction - audio_len + mel = log_mel_spectrogram(audio, n_mels=n_mels, padding=audio_pad) + else: + mel = log_mel_spectrogram(audio, n_mels=n_mels) # [F,T] + return mel + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class Conv1d(nn.Conv1d): + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)) + + +class ConvTranspose1d(nn.ConvTranspose1d): + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: + return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + self.use_flash_attention = True + + def forward( + self, + x: Tensor, + cu_seqlens=None, + ): + q = self.query(x) + k = self.key(x) + v = self.value(x) + + if self.use_flash_attention: + if flash_attn_varlen_func is None: + x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) + else: + if q.dtype not in [torch.float16, torch.bfloat16]: + x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) + self.use_flash_attention = False + else: + x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens) + else: + x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) + + output = self.out(x) + return output + + def qkv_flash_attention(self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None): + n_ctx, n_state = q.shape + # scale = (n_state // self.n_head) ** -0.25 + q = q.view(n_ctx, self.n_head, -1) # (batch_size, seqlen, nheads, headdim) + k = k.view(n_ctx, self.n_head, -1) + v = v.view(n_ctx, self.n_head, -1) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + x = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0) + x = x.reshape(n_ctx, n_state) + return x + + def qkv_attention_manual(self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor): + n_ctx, n_state = q.shape + head_dim = n_state // self.n_head + scale = head_dim**-0.5 + + q = q.view(n_ctx, self.n_head, head_dim) + k = k.view(n_ctx, self.n_head, head_dim) + v = v.view(n_ctx, self.n_head, head_dim) + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + batch_size = len(seqlens) + max_seqlen = max(seqlens) + + q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device) + k_padded = torch.zeros_like(q_padded) + v_padded = torch.zeros_like(q_padded) + + for i in range(batch_size): + start_idx = 
cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + seq_len = seqlens[i] + q_padded[i, :seq_len] = q[start_idx:end_idx] + k_padded[i, :seq_len] = k[start_idx:end_idx] + v_padded[i, :seq_len] = v[start_idx:end_idx] + + q_padded = q_padded.transpose(1, 2) + k_padded = k_padded.transpose(1, 2) + v_padded = v_padded.transpose(1, 2) + + attn_mask = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None] + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) + + attn_mask = attn_mask.masked_fill(attn_mask == 0, -torch.finfo(q.dtype).max) + + attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale + attn_scores = attn_scores + attn_mask + attn_weights = F.softmax(attn_scores, dim=-1) + + context = torch.matmul(attn_weights, v_padded) + + context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state) + + output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0) + + assert output_packed.shape == (n_ctx, n_state) + + return output_packed + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, enable_mp: bool = False, sequence_parallel: bool = False): + super().__init__() + n_mlp = n_state * 4 + self.attn_ln = nn.LayerNorm(n_state) + self.mlp_ln = nn.LayerNorm(n_state) + + self.attn = MultiHeadAttention(n_state, n_head) + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + + def forward(self, x: Tensor, cu_seqlens=None): + x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class WhisperEncoder(nn.Module): + def __init__( + self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + n_window: int = 1500, + output_dim: int = 512, + grad_checkpointing: bool = False, + enable_mp: bool = False, + audio_sequence_parallel: bool = False, + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + self.n_layer = n_layer + self.n_mels = n_mels + + self.blocks = nn.ModuleList( + [ + ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel) + for _ in range(n_layer) + ] + ) + self.ln_post = nn.LayerNorm(n_state) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + + self.proj = torch.nn.Linear(n_state, output_dim) + + self.audio_bos_eos_token = nn.Embedding(2, output_dim) + + self.output_dim = output_dim + self.grad_checkpointing = grad_checkpointing + self.enable_mp = enable_mp + self.n_head = n_head + self.n_state = n_state + self.n_window = n_window + + self.audio_sequence_parallel = audio_sequence_parallel + + self.tp_world_size = 1 + + self.set_audio_sync() + + def set_audio_sync(self): + for name, param in self.named_parameters(): + if not name.startswith("blocks"): + setattr(param, "audio_sync", True) + + def forward( + self, x_list: list[Tensor], audio_mellens: list[int], audio_aftercnnlens: list[int], audio_seqlens: list[int] + ): + """ + x : torch.Tensor, shape = (n_mels, n_ctx) + the mel spectrogram of the audio + """ + + aftercnn_x_list = [] + for each_x in x_list: + each_x_split_list = each_x.split(self.n_window * 2, dim=1) + for each_x_split in each_x_split_list: + each_x_split = F.gelu(self.conv1(each_x_split)) + each_x_split = F.gelu(self.conv2(each_x_split)) + each_x_split = each_x_split.permute(1, 0) # L,D + 
each_positional_embedding_split = self.positional_embedding[: each_x_split.shape[0]] + aftercnn_x_list.append(each_x_split + each_positional_embedding_split.to(each_x_split.dtype)) + + x = torch.cat(aftercnn_x_list, dim=0) + + output_list = [] + for item in audio_aftercnnlens: + while item > self.n_window: + output_list.append(self.n_window) + item -= self.n_window + output_list.append(item) + + cu_seqlens = list(accumulate(output_list, func=operator.add, initial=0)) + cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32) + + layer_id = 0 + for block in self.blocks: + layer_id += 1 + x = block(x, cu_seqlens=cu_seqlens) + + if self.avg_pooler: + x_list = x.split(audio_aftercnnlens, dim=0) + token_x_list = [] + for x in x_list: + x = x.permute(1, 0) + x = self.avg_pooler(x) + x = x.permute(1, 0) + token_x_list.append(x) + x = torch.cat(token_x_list, dim=0) + + x = self.ln_post(x) + x = self.proj(x) + + output = torch.zeros((x.size(0) + len(audio_seqlens) * 2, x.size(1)), device=x.device, dtype=x.dtype) + + audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0)) + start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32) + end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1 + + audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool) + audio_tokens_mask[start_ids] = False + audio_tokens_mask[end_ids] = False + output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype) + output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype) + output[audio_tokens_mask] = x + return output + + def lock(self, layers: int): + self.conv1.requires_grad_(False) + self.conv2.requires_grad_(False) + for i in range(min(layers, len(self.blocks))): + self.blocks[i].requires_grad_(False) diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..747ca8f0cddb01cbfaab4a8ab59baaf8596c7e6b --- /dev/null +++ b/vllm_omni/model_executor/models/registry.py @@ -0,0 +1,82 @@ +from vllm.model_executor.models.registry import _VLLM_MODELS, _LazyRegisteredModel, _ModelRegistry + +_OMNI_MODELS = { + "Qwen2_5OmniForConditionalGeneration": ( + "qwen2_5_omni", + "qwen2_5_omni", + "Qwen2_5OmniForConditionalGeneration", + ), + "Qwen2_5OmniThinkerModel": ( + "qwen2_5_omni", + "qwen2_5_omni_thinker", + "Qwen2_5OmniThinkerForConditionalGeneration", + ), + "Qwen2_5OmniTalkerModel": ( + "qwen2_5_omni", + "qwen2_5_omni_talker", + "Qwen2_5OmniTalkerForConditionalGeneration", + ), + "Qwen2_5OmniToken2WavModel": ( + "qwen2_5_omni", + "qwen2_5_omni_token2wav", + "Qwen2_5OmniToken2WavForConditionalGenerationVLLM", + ), + "Qwen2_5OmniToken2WavDiTModel": ( + "qwen2_5_omni", + "qwen2_5_omni_token2wav", + "Qwen2_5OmniToken2WavModel", + ), + "Qwen2ForCausalLM_old": ("qwen2_5_omni", "qwen2_old", "Qwen2ForCausalLM"), # need to discuss + # Qwen3 Omni MoE models + "Qwen3OmniMoeForConditionalGeneration": ( + "qwen3_omni", + "qwen3_omni", + "Qwen3OmniMoeForConditionalGeneration", + ), + "Qwen3OmniMoeThinkerForConditionalGeneration": ( + "qwen3_omni", + "qwen3_omni_moe_thinker", + "Qwen3OmniMoeThinkerForConditionalGeneration", + ), + "Qwen3OmniMoeTalkerForConditionalGeneration": ( + "qwen3_omni", + "qwen3_omni_moe_talker", + "Qwen3OmniMoeTalkerForConditionalGeneration", + ), + "Qwen3OmniMoeCode2Wav": ( + "qwen3_omni", + "qwen3_omni_code2wav", + "Qwen3OmniMoeCode2Wav", + ), + 
"Qwen3TTSForConditionalGeneration": ( + "qwen3_tts", + "qwen3_tts", + "Qwen3TTSModelForGeneration", + ), +} + + +_VLLM_OMNI_MODELS = { + **_VLLM_MODELS, + **_OMNI_MODELS, +} + + +OmniModelRegistry = _ModelRegistry( + { + **{ + model_arch: _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() + }, + **{ + model_arch: _LazyRegisteredModel( + module_name=f"vllm_omni.model_executor.models.{mod_folder}.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_folder, mod_relname, cls_name) in _OMNI_MODELS.items() + }, + } +) diff --git a/vllm_omni/model_executor/models/utils.py b/vllm_omni/model_executor/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7f73067b3bc6e5c6eb8f399c1e976ba6c0fed863 --- /dev/null +++ b/vllm_omni/model_executor/models/utils.py @@ -0,0 +1,39 @@ +import torch +from vllm.model_executor.models.utils import maybe_prefix + + +def add_prefix_to_loaded_weights(weights: set[str], prefix: str) -> set[str]: + """ + Add a prefix to the names of the loaded weights. + """ + return {maybe_prefix(prefix, name) for name in weights} + + +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + if lst.numel() == 0: + return [] + + # Move to CPU and convert to list once (High Speedup) + # using .item() inside a loop is very slow. + data_list = lst.detach().cpu().tolist() + + # Calculate max on the list or tensor (Tensor max is fast enough) + max_val = int(torch.max(lst).item()) + + # Pre-allocate buckets + ranges: list[list[int]] = [[] for _ in range((max_val // interval) + 1)] + + for num in data_list: + index = int(num // interval) + ranges[index].append(num) + + return ranges + + +def safe_tensor_reshape(tensor: torch.Tensor, shape: tuple) -> torch.Tensor: + """ + Reshape a tensor safely. 
+ """ + if tensor is None: + return None + return tensor.reshape(shape) diff --git a/vllm_omni/model_executor/stage_configs/__init__.py b/vllm_omni/model_executor/stage_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/stage_configs/bagel.yaml b/vllm_omni/model_executor/stage_configs/bagel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c60f3b7008d8f25843153f0e45cd0b25e380a1ea --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/bagel.yaml @@ -0,0 +1,82 @@ +# Stage 0: Thinker (multimodal understanding + text generation) + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: BagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.35 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished #or special token generated + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.55 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration (optional) + # More connectors will be supported in the future. 
+ connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c19302c765e62c0902a3ca78e9f487f741d716b9 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml @@ -0,0 +1,94 @@ +# Stage 0: Thinker (multimodal understanding + text generation) + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: BagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.35 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished #or special token generated + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + output_connectors: + to_stage_1: mooncake_connector + + + - stage_id: 1 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.55 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + input_connectors: + from_stage_0: mooncake_connector + + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration (optional) + # More connectors will be supported in the future. + connectors: + # Mooncake connector for cross-node/intra-node communication + mooncake_connector: + name: MooncakeConnector + extra: + host: "127.0.0.1" + metadata_server: "http://10.90.67.86:8080/metadata" + master: "10.90.67.86:50051" + segment: 512000000 # 512MB + localbuf: 64000000 # 64MB + proto: "tcp" + + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6ed976607707a09b1ff634838d04a584438887f --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,106 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 2x H100-80G GPU. 
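+# Stages: 0 = thinker, 1 = talker, 2 = code2wav (see stage_args below).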
+stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + async_scheduling: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68f5817a75d4105ebbd88923ff7338b04d78dfac --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml @@ -0,0 +1,140 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 1x H100-80G GPU. 
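+# This variant routes stage-to-stage data through the mooncake connector
+# declared under runtime.connectors (see the per-stage input_connectors /
+# output_connectors entries below).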
+stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + # Distributed connector configuration (optional) + output_connectors: + to_stage_1: mooncake_connector + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + # Distributed connector configuration (optional) + input_connectors: + from_stage_0: mooncake_connector + output_connectors: + to_stage_2: mooncake_connector + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + # Distributed connector configuration (optional) + input_connectors: + from_stage_1: mooncake_connector + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + + # Distributed connectors configuration (optional) + # More connectors will be supported in the future. 
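+ # Only mooncake_connector is referenced by the stages in this file;
+ # yuanrong_connector and shared_memory_connector are provided as alternatives.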
+ connectors: + # Mooncake connector for cross-node/intra-node communication + mooncake_connector: + name: MooncakeConnector + extra: + host: "127.0.0.1" + metadata_server: "http://10.90.67.86:8080/metadata" + master: "10.90.67.86:50051" + segment: 512000000 # 512MB + localbuf: 64000000 # 64MB + proto: "tcp" + + # Yuanrong connector for cross-node/intra-node communication + yuanrong_connector: + name: YuanrongConnector + extra: + host: "127.0.0.1" + port: "35000" + + # SharedMemory connector for intra-node communication + # Alternative SHM connector with different threshold + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3dcf940f4bde55097861d22327deb752c0a96cd --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml @@ -0,0 +1,101 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. +async_chunk: false +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 64 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 64 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + 
worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8cb67dd531c79defd14cd5b48e63dc8f544f27f --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml @@ -0,0 +1,101 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) +# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 64 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 64 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk + engine_input_source: [0] + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 2048 # TODO: The max_tokens of the async_chunk feature cannot exceed 2048. 
+ seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 10000 + hf_config_name: thinker_config + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0bc6e48e594026aa1e86c662eb5d2ec5c6cfdc0d --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml @@ -0,0 +1,143 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 1 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + # Distributed connector configuration + output_connectors: + to_stage_1: connector_of_mooncake + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + # Distributed connector configuration + input_connectors: + from_stage_0: connector_of_mooncake + output_connectors: 
+ to_stage_2: connector_of_mooncake + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + # Distributed connector configuration + input_connectors: + from_stage_1: connector_of_mooncake + +# Top-level runtime config: default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration + connectors: + # Mooncake connector for cross-node/intra-node communication + connector_of_mooncake: + name: MooncakeConnector + extra: + host: "127.0.0.1" + metadata_server: "http://10.90.67.86:8080/metadata" + master: "10.90.67.86:50051" + segment: 512000000 # 512MB + localbuf: 64000000 # 64MB + proto: "tcp" + + # SharedMemory connector for intra-node communication + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + edges: + - from: 0 + to: 1 + window_size: -1 + - from: 1 + to: 2 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d408dbab91e455ebe56dbb60552be3922d9d25e7 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -0,0 +1,22 @@ +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + + final_output: true + final_output_type: audio diff --git a/vllm_omni/model_executor/stage_input_processors/__init__.py b/vllm_omni/model_executor/stage_input_processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..e994589c4dd0ae3319bc8f9b4ea42b8bd945b21f --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -0,0 +1,61 @@ +import torch +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + +TALKER_CODEC_PAD_TOKEN_ID = 8292 +TALKER_CODEC_START_TOKEN_ID = 
8293 +TALKER_CODEC_END_TOKEN_ID = 8294 + + +def thinker2talker( + stage_list, + engine_input_source, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +): + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid stage_id: {source_stage_id}") + if stage_list[source_stage_id].engine_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + thinker_outputs = stage_list[source_stage_id].engine_outputs + talker_inputs = [] + if not isinstance(prompt, list): + prompt = [prompt] + multi_modal_data = { + thinker_output.request_id: p.get("multi_modal_data", None) for thinker_output, p in zip(thinker_outputs, prompt) + } + + for i, thinker_output in enumerate(thinker_outputs): + output = thinker_output.outputs[0] + prompt_token_ids = thinker_output.prompt_token_ids + thinker_output_ids = output.token_ids + prompt_token_ids_len = len(prompt_token_ids) + latent = output.multimodal_output["latent"] + thinker_hidden_states = latent.clone().detach().to(latent.device) + additional_information = { + "thinker_result": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32), + "prompt_embeds": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32), + "prompt_token_ids": prompt_token_ids, + "thinker_output_token_ids": thinker_output_ids, + "thinker_result_shape": list(thinker_hidden_states[prompt_token_ids_len:].shape), + "prompt_embeds_shape": list(thinker_hidden_states[:prompt_token_ids_len].shape), + } + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[TALKER_CODEC_START_TOKEN_ID] + + [TALKER_CODEC_PAD_TOKEN_ID] * (len(prompt_token_ids)) + + [TALKER_CODEC_END_TOKEN_ID], + additional_information=additional_information, + multi_modal_data=( + multi_modal_data[thinker_output.request_id] + if requires_multimodal_data and multi_modal_data is not None + else None + ), + mm_processor_kwargs=None, + ) + ) + return talker_inputs diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..03daa8e42f004ed151f2d32068a6673c76f28144 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. 
+"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition.""" + +from typing import Any + +import torch +from vllm.inputs import TextPrompt +from vllm.platforms import current_platform + +from vllm_omni.engine import OmniEngineCoreRequest +from vllm_omni.inputs.data import OmniTokensPrompt + + +def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: + im_start_token_id = 151644 + system_token_id = 8948 + user_token_id = 872 + assistant_token_id = 77091 + + thinker_sequences = torch.tensor(info["thinker_sequences"], dtype=torch.long, device=device).unsqueeze(0) # [1, T] + + input_ids = torch.tensor(info["thinker_input_ids"], dtype=torch.long, device=device).unsqueeze(0) # [1, T] + + im_start_indexes = torch.cat( + [ + torch.nonzero(input_ids[0] == im_start_token_id).squeeze(1), + torch.tensor([thinker_sequences.shape[-1]], device=input_ids.device, dtype=input_ids.dtype), + ], + dim=0, + ) + + sum_user_len = 0 + assistant_len = 0 + for i in range(len(im_start_indexes) - 1): + s = int(im_start_indexes[i].item()) + e = int(im_start_indexes[i + 1].item()) + role = int(input_ids[0, s + 1].item()) + if role == system_token_id: + continue + elif role == user_token_id: + sum_user_len += e - s + elif role == assistant_token_id and i == len(im_start_indexes) - 2: + assistant_len += 9 # 3 + 4 + 1 + 1 + else: + pass + + return sum_user_len + assistant_len + + +# ========================= +# Common helpers +# ========================= + + +def _ensure_list(x): + """Convert ConstantList / tensor-like to Python list.""" + if hasattr(x, "_x"): + return list(x._x) + elif not isinstance(x, list): + return x + return list(x) + + +def _validate_stage_inputs(stage_list, engine_input_source): + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + stage_id = engine_input_source[0] + if stage_id >= len(stage_list): + raise IndexError(f"Invalid stage_id: {stage_id}") + + stage = stage_list[stage_id] + if stage.engine_outputs is None: + raise RuntimeError(f"Stage {stage_id} has no outputs yet") + + return stage.engine_outputs + + +# ========================= +# Thinker -> Talker +# ========================= + + +def thinker2talker_async_chunk( + pooling_output: dict[str, Any], + request: OmniEngineCoreRequest, +) -> list[dict[str, Any]]: + """ + Process thinker outputs to create talker inputs. + 1. thinker's text generation outputs (token IDs + hidden states) + 2. Split hidden states into: prompt embeddings + generated embeddings + 3. 
Package for talker with additional information + """ + all_token_ids = request.all_token_ids # prefill + decode + prompt_token_ids = request.prompt_token_ids + + # Convert ConstantList to regular list for OmniSerializer serialization + all_token_ids = _ensure_list(all_token_ids) + prompt_token_ids = _ensure_list(prompt_token_ids) + + thinker_output = pooling_output + + talker_additional_info = { + "thinker_embeddings": thinker_output.get("0").detach().cpu(), + "thinker_hidden_states": thinker_output.get("24").detach().cpu(), + "thinker_sequences": all_token_ids, + "thinker_input_ids": prompt_token_ids, + # Provide thinker-side TTS token embeddings for talker projection + "tts_bos_embed": thinker_output.get("tts_bos_embed").detach().cpu(), + "tts_eos_embed": thinker_output.get("tts_eos_embed").detach().cpu(), + "tts_pad_embed": thinker_output.get("tts_pad_embed").detach().cpu(), + "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + } + + return talker_additional_info + + +def thinker2talker( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """ + Process thinker outputs to create talker inputs. + + Workflow: + 1. Extract thinker's text generation outputs (token IDs + hidden states) + 2. Split hidden states into: prompt embeddings + generated embeddings + 3. Package for talker with additional information + + Args: + stage_list: List of stage objects + engine_input_source: Source stage IDs (typically [0] for thinker) + prompt: Original prompt data + requires_multimodal_data: Whether multimodal data is required + + Returns: + List of OmniTokensPrompt for talker stage + """ + thinker_outputs = _validate_stage_inputs(stage_list, engine_input_source) + talker_inputs: list[OmniTokensPrompt] = [] + + device = torch.device(current_platform.device_type) + + # Process each thinker output + for thinker_output in thinker_outputs: + output = thinker_output.outputs[0] + + info = { + "thinker_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), + "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float), + "thinker_sequences": ( + thinker_output.prompt_token_ids + output.token_ids + ), # the thinker_sequences is the whole ids + "thinker_input_ids": thinker_output.prompt_token_ids, + # Provide thinker-side TTS token embeddings for talker projection + "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float), + "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), + "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), + } + + prompt_len = _compute_talker_prompt_ids_length(info, device=device) + + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * prompt_len, + additional_information=info, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return talker_inputs + + +# ========================= +# Talker -> Code2Wav +# ========================= + + +def talker2code2wav_async_chunk( + pooling_output: dict[str, Any], + request: OmniEngineCoreRequest, +): + """ + Pooling version. 
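+ Reads "code_predictor_codes" from the pooling output. Returns an empty list
+ while no non-zero codes are available yet; otherwise returns a dict with the
+ flattened codec codes and a "finished" flag for the request.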
+ """ + if "code_predictor_codes" not in pooling_output: + return [] + + code_predictor_codes = pooling_output["code_predictor_codes"] + + if code_predictor_codes is None: + return [] + if isinstance(code_predictor_codes, torch.Tensor): + if code_predictor_codes.numel() == 0: + return [] + elif hasattr(code_predictor_codes, "__len__"): + if len(code_predictor_codes) == 0: + return [] + + if isinstance(code_predictor_codes, torch.Tensor): + if not code_predictor_codes.any(): + return [] + else: + code_tensor = torch.tensor(code_predictor_codes, dtype=torch.long) + if not code_tensor.any(): + return [] + + codec_codes = code_predictor_codes.to(torch.long).transpose(0, 1).cpu().to(torch.long).reshape(-1).tolist() + if sum(codec_codes) == 0: + return [] + + return { + "code_predictor_codes": codec_codes, + "finished": torch.tensor(request.is_finished(), dtype=torch.bool), + } + + +def talker2code2wav( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """ + Process talker outputs to create code2wav inputs. + + Workflow: + 1. Extract talker's codec code outputs (8-layer RVQ codes) + 2. Flatten codes for code2wav input + 3. Package for code2wav stage + + Args: + stage_list: List of stage objects + engine_input_source: Source stage IDs (typically [1] for talker) + prompt: Original prompt data + requires_multimodal_data: Whether multimodal data is required + + Returns: + List of OmniTokensPrompt for code2wav stage + """ + talker_outputs = _validate_stage_inputs(stage_list, engine_input_source) + code2wav_inputs: list[OmniTokensPrompt] = [] + # Process each talker output + for talker_output in talker_outputs: + output = talker_output.outputs[0] + seq_len = len(output.token_ids) - 1 + # Extract codec codes from talker output + # Expected shape: [8, seq_len] (8-layer RVQ codes) + codec_codes = ( + output.multimodal_output["code_predictor_codes"][-seq_len:] + .to(torch.long) + .transpose(0, 1) + .cpu() + .to(torch.long) + .reshape(-1) + .tolist() + ) # 16, seq_len + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=codec_codes, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return code2wav_inputs diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..f981e1edd5cf37bf1ec61a9b0a555ef9ce6f3baa --- /dev/null +++ b/vllm_omni/outputs.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass, field +from typing import Any + +import torch +from PIL import Image +from vllm.outputs import RequestOutput +from vllm.v1.outputs import ModelRunnerOutput + +from vllm_omni.inputs.data import OmniPromptType + + +class OmniModelRunnerOutput(ModelRunnerOutput): + """Model runner output for omni models. + + Extends the base ModelRunnerOutput with support for multimodal outputs + that may be produced by non-autoregressive stages. + + Attributes: + multimodal_outputs: Optional dictionary mapping modality names to + output tensors (e.g., {"image": tensor, "audio": tensor}) + """ + + multimodal_outputs: dict[str, torch.Tensor] | None = None + # IDs of requests whose KV cache has been extracted from GPU/NPU to CPU. + # The Scheduler can safely free the block tables for these requests. + kv_extracted_req_ids: list[str] | None = None + + +@dataclass +class OmniRequestOutput: + """Unified request output for both pipeline stages and diffusion models. + + This class handles outputs from: + 1. 
Multi-stage LLM pipelines (with stage_id, final_output_type, request_output) + 2. Diffusion models (with images, prompt, metrics) + + Attributes: + request_id: Unique identifier for this request + finished: Whether generation is complete + stage_id: Identifier of the stage that produced this output (pipeline mode) + final_output_type: Type of output ("text", "image", "audio", "latents") + request_output: The underlying RequestOutput from the stage (pipeline mode) + images: List of generated PIL images (diffusion mode) + prompt: The prompt used for generation (diffusion mode) + latents: Optional tensor of latent representations (diffusion mode) + metrics: Optional dictionary of generation metrics + """ + + request_id: str = "" + finished: bool = True + + # Pipeline stage fields + stage_id: int | None = None + final_output_type: str = "text" + request_output: RequestOutput | None = None + + # Diffusion model fields + images: list[Image.Image] = field(default_factory=list) + prompt: OmniPromptType | None = None + latents: torch.Tensor | None = None + metrics: dict[str, Any] = field(default_factory=dict) + _multimodal_output: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_pipeline( + cls, + stage_id: int, + final_output_type: str, + request_output: RequestOutput, + ) -> "OmniRequestOutput": + """Create output from pipeline stage. + + Args: + stage_id: Stage identifier + final_output_type: Type of output + request_output: The stage's output + + Returns: + OmniRequestOutput configured for pipeline mode + """ + return cls( + request_id=getattr(request_output, "request_id", ""), + stage_id=stage_id, + final_output_type=final_output_type, + request_output=request_output, + finished=True, + ) + + @classmethod + def from_diffusion( + cls, + request_id: str, + images: list[Image.Image], + prompt: OmniPromptType | None = None, + metrics: dict[str, Any] | None = None, + latents: torch.Tensor | None = None, + multimodal_output: dict[str, Any] | None = None, + final_output_type: str = "image", + ) -> "OmniRequestOutput": + """Create output from diffusion model. + + Args: + request_id: Request identifier + images: Generated images + prompt: The prompt used + metrics: Generation metrics + latents: Optional latent tensors + + Returns: + OmniRequestOutput configured for diffusion mode + """ + return cls( + request_id=request_id, + final_output_type=final_output_type, + images=images, + prompt=prompt, + latents=latents, + metrics=metrics or {}, + _multimodal_output=multimodal_output or {}, + finished=True, + ) + + @property + def multimodal_output(self) -> dict[str, Any]: + """Return multimodal output from the underlying request output or local field. + + For pipeline outputs, this checks completion outputs first, then + request_output.multimodal_output. + For diffusion outputs, this returns the local _multimodal_output field. + """ + if self.request_output is not None: + # CompletionOutput is where the output processor attaches audio/image + # tensors for pipeline requests. + for output in getattr(self.request_output, "outputs", []): + mm_output = getattr(output, "multimodal_output", None) + if mm_output: + return mm_output + return getattr(self.request_output, "multimodal_output", {}) + return self._multimodal_output + + @property + def num_images(self) -> int: + """Return the number of generated images.""" + return len(self.images) + + # Pass-through properties keep vLLM serving codepaths compatible with + # OmniRequestOutput for pipeline outputs (Issue #345). 
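+    # Each property defers to the wrapped request_output when it is present; in
+    # diffusion mode (no request_output) it falls back to a neutral default
+    # (None or an empty list) so serving code can handle both modes uniformly.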
+ @property + def prompt_token_ids(self) -> list[int] | None: + """Return prompt token IDs from the underlying request output. + + This property is required for compatibility with vLLM's streaming + chat completion generator which checks res.prompt_token_ids. + """ + if self.request_output is not None: + return getattr(self.request_output, "prompt_token_ids", None) + return None + + @property + def outputs(self) -> list[Any]: + """Return outputs from the underlying request output. + + This property is required for compatibility with vLLM's streaming + and non-streaming chat completion generators. + """ + if self.request_output is not None: + return getattr(self.request_output, "outputs", []) + return [] + + @property + def encoder_prompt_token_ids(self) -> list[int] | None: + """Return encoder prompt token IDs from the underlying request output.""" + if self.request_output is not None: + return getattr(self.request_output, "encoder_prompt_token_ids", None) + return None + + @property + def prompt_logprobs(self) -> Any: + """Return prompt logprobs from the underlying request output.""" + if self.request_output is not None: + return getattr(self.request_output, "prompt_logprobs", None) + return None + + @property + def num_cached_tokens(self) -> int | None: + """Return number of cached tokens from the underlying request output.""" + if self.request_output is not None: + return getattr(self.request_output, "num_cached_tokens", None) + return None + + @property + def kv_transfer_params(self) -> Any: + """Return KV transfer params from the underlying request output.""" + if self.request_output is not None: + return getattr(self.request_output, "kv_transfer_params", None) + return None + + @property + def is_diffusion_output(self) -> bool: + """Check if this is a diffusion model output.""" + return len(self.images) > 0 or self.final_output_type == "image" + + @property + def is_pipeline_output(self) -> bool: + """Check if this is a pipeline stage output.""" + return self.stage_id is not None and self.request_output is not None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result = { + "request_id": self.request_id, + "finished": self.finished, + "final_output_type": self.final_output_type, + } + + if self.is_diffusion_output: + result.update( + { + "num_images": self.num_images, + "prompt": self.prompt, + "metrics": self.metrics, + } + ) + + if self.is_pipeline_output: + result.update( + { + "stage_id": self.stage_id, + } + ) + + return result + + def __repr__(self) -> str: + """Custom repr to properly show image count instead of image objects.""" + # For images, show count instead of full list + images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]" + + # Build repr string + parts = [ + f"request_id={self.request_id!r}", + f"finished={self.finished}", + f"stage_id={self.stage_id}", + f"final_output_type={self.final_output_type!r}", + f"request_output={self.request_output}", + f"images={images_repr}", + f"prompt={self.prompt!r}", + f"latents={self.latents}", + f"metrics={self.metrics}", + f"multimodal_output={self._multimodal_output}", + ] + + return f"OmniRequestOutput({', '.join(parts)})" diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..90c718d40ddfe90d7369bc9d16119531f3449030 --- /dev/null +++ b/vllm_omni/patch.py @@ -0,0 +1,33 @@ +import sys + +from vllm.inputs.data import TokensPrompt as _OriginalTokensPrompt +from 
vllm.model_executor.layers.rotary_embedding import ( + MRotaryEmbedding as _OriginalMRotaryEmbedding, +) +from vllm.v1.engine import EngineCoreOutput as _OriginalEngineCoreOutput +from vllm.v1.engine import EngineCoreOutputs as _OriginalEngineCoreOutputs +from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest +from vllm.v1.request import Request as _OriginalRequest + +import vllm_omni.logger # noqa: F401 +from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest +from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding +from vllm_omni.request import OmniRequest + +for module_name, module in sys.modules.items(): + # only do patch on module of vllm, pass others + if "vllm" not in module_name: + continue + if hasattr(module, "EngineCoreOutput") and module.EngineCoreOutput == _OriginalEngineCoreOutput: + module.EngineCoreOutput = OmniEngineCoreOutput + if hasattr(module, "EngineCoreOutputs") and module.EngineCoreOutputs == _OriginalEngineCoreOutputs: + module.EngineCoreOutputs = OmniEngineCoreOutputs + if hasattr(module, "TokensPrompt") and module.TokensPrompt == _OriginalTokensPrompt: + module.TokensPrompt = OmniTokensPrompt + if hasattr(module, "MRotaryEmbedding") and module.MRotaryEmbedding == _OriginalMRotaryEmbedding: + module.MRotaryEmbedding = OmniMRotaryEmbedding + if hasattr(module, "Request") and module.Request == _OriginalRequest: + module.Request = OmniRequest + if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest: + module.EngineCoreRequest = OmniEngineCoreRequest diff --git a/vllm_omni/platforms/__init__.py b/vllm_omni/platforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b76c59a67858530d6c289ff3fa90d292439bdb74 --- /dev/null +++ b/vllm_omni/platforms/__init__.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import traceback +from itertools import chain +from typing import TYPE_CHECKING + +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_xccl + +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum +from vllm_omni.plugins import ( + OMNI_PLATFORM_PLUGINS_GROUP, + load_omni_plugins_by_group, +) + +logger = logging.getLogger(__name__) + + +def cuda_omni_platform_plugin() -> str | None: + """Check if CUDA OmniPlatform should be activated.""" + is_cuda = False + logger.debug("Checking if CUDA OmniPlatform is available.") + try: + from vllm.utils.import_utils import import_pynvml + + pynvml = import_pynvml() + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + logger.debug("Confirmed CUDA OmniPlatform is available.") + else: + logger.debug("CUDA OmniPlatform is not available because no GPU is found.") + finally: + pynvml.nvmlShutdown() + except Exception as e: + logger.debug("CUDA OmniPlatform is not available because: %s", str(e)) + + return "vllm_omni.platforms.cuda.platform.CudaOmniPlatform" if is_cuda else None + + +def rocm_omni_platform_plugin() -> str | None: + """Check if ROCm OmniPlatform should be activated.""" + is_rocm = False + logger.debug("Checking if ROCm OmniPlatform is available.") + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + logger.debug("Confirmed ROCm OmniPlatform is 
available.") + else: + logger.debug("ROCm OmniPlatform is not available because no GPU is found.") + finally: + amdsmi.amdsmi_shut_down() + except Exception as e: + logger.debug("ROCm OmniPlatform is not available because: %s", str(e)) + + return "vllm_omni.platforms.rocm.platform.RocmOmniPlatform" if is_rocm else None + + +def npu_omni_platform_plugin() -> str | None: + """Check if NPU OmniPlatform should be activated.""" + is_npu = False + logger.debug("Checking if NPU OmniPlatform is available.") + try: + import torch + + if hasattr(torch, "npu") and torch.npu.is_available(): + is_npu = True + logger.debug("Confirmed NPU OmniPlatform is available.") + except Exception as e: + logger.debug("NPU OmniPlatform is not available because: %s", str(e)) + + return "vllm_omni.platforms.npu.platform.NPUOmniPlatform" if is_npu else None + + +def xpu_omni_platform_plugin() -> str | None: + """Check if XPU OmniPlatform should be activated.""" + is_xpu = False + logger.debug("Checking if XPU OmniPlatform is available.") + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import torch + + if supports_xccl(): + dist_backend = "xccl" + else: + dist_backend = "ccl" + import oneccl_bindings_for_pytorch # noqa: F401 + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + is_xpu = True + from vllm_omni.platforms.xpu import XPUOmniPlatform + + XPUOmniPlatform.dist_backend = dist_backend + logger.debug("Confirmed %s backend is available.", XPUOmniPlatform.dist_backend) + logger.debug("Confirmed XPU platform is available.") + except Exception as e: + logger.debug("XPU omni platform is not available because: %s", str(e)) + + return "vllm_omni.platforms.xpu.platform.XPUOmniPlatform" if is_xpu else None + + +builtin_omni_platform_plugins = { + "cuda": cuda_omni_platform_plugin, + "rocm": rocm_omni_platform_plugin, + "npu": npu_omni_platform_plugin, + "xpu": xpu_omni_platform_plugin, +} + + +def resolve_current_omni_platform_cls_qualname() -> str: + """Resolve the current OmniPlatform class qualified name.""" + platform_plugins = load_omni_plugins_by_group(OMNI_PLATFORM_PLUGINS_GROUP) + + activated_plugins = [] + + for name, func in chain(builtin_omni_platform_plugins.items(), platform_plugins.items()): + try: + assert callable(func) + platform_cls_qualname = func() + if platform_cls_qualname is not None: + activated_plugins.append(name) + except Exception: + pass + + activated_builtin_plugins = list(set(activated_plugins) & set(builtin_omni_platform_plugins.keys())) + activated_oot_plugins = list(set(activated_plugins) & set(platform_plugins.keys())) + + if len(activated_oot_plugins) >= 2: + raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {activated_oot_plugins}") + elif len(activated_oot_plugins) == 1: + platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]() + logger.info("OmniPlatform plugin %s is activated", activated_oot_plugins[0]) + elif len(activated_builtin_plugins) >= 2: + raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {activated_builtin_plugins}") + elif len(activated_builtin_plugins) == 1: + platform_cls_qualname = builtin_omni_platform_plugins[activated_builtin_plugins[0]]() + logger.debug("Automatically detected OmniPlatform %s.", activated_builtin_plugins[0]) + else: + platform_cls_qualname = "vllm_omni.platforms.interface.UnspecifiedOmniPlatform" + logger.debug("No platform detected, vLLM-Omni is running on UnspecifiedOmniPlatform") + + return platform_cls_qualname + + 
+_current_omni_platform = None +_init_trace: str = "" + +if TYPE_CHECKING: + current_omni_platform: OmniPlatform + + +def __getattr__(name: str): + if name == "current_omni_platform": + # Lazy init current_omni_platform + global _current_omni_platform + if _current_omni_platform is None: + platform_cls_qualname = resolve_current_omni_platform_cls_qualname() + _current_omni_platform = resolve_obj_by_qualname(platform_cls_qualname)() + global _init_trace + _init_trace = "".join(traceback.format_stack()) + return _current_omni_platform + elif name in globals(): + return globals()[name] + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + +def __setattr__(name: str, value): # noqa: N807 + if name == "current_omni_platform": + global _current_omni_platform + _current_omni_platform = value + elif name in globals(): + globals()[name] = value + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + +__all__ = [ + "OmniPlatform", + "OmniPlatformEnum", + "current_omni_platform", + "_init_trace", +] diff --git a/vllm_omni/platforms/cuda/__init__.py b/vllm_omni/platforms/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91ef864be2bc3650efe1835e233a84e9d9938d48 --- /dev/null +++ b/vllm_omni/platforms/cuda/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.cuda.platform import CudaOmniPlatform + +__all__ = ["CudaOmniPlatform"] diff --git a/vllm_omni/platforms/cuda/platform.py b/vllm_omni/platforms/cuda/platform.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf740a01883bedff14064eee6f76d8d746517a6 --- /dev/null +++ b/vllm_omni/platforms/cuda/platform.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger +from vllm.platforms.cuda import CudaPlatformBase +from vllm.platforms.interface import DeviceCapability + +from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum + +logger = init_logger(__name__) + + +class CudaOmniPlatform(OmniPlatform, CudaPlatformBase): + """CUDA/GPU implementation of OmniPlatform (default). + + Inherits all CUDA-specific implementations from vLLM's CudaPlatform, + and adds Omni-specific interfaces from OmniPlatform. 
+ """ + + _omni_enum = OmniPlatformEnum.CUDA + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + return "vllm_omni.worker.gpu_ar_worker.GPUARWorker" + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + return "vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker" + + @classmethod + def get_default_stage_config_path(cls) -> str: + return "vllm_omni/model_executor/stage_configs" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + from vllm_omni.diffusion.envs import PACKAGES_CHECKER + + # Check compute capability for Flash Attention support + # Flash Attention requires compute capability >= 8.0 and < 10.0 + compute_capability = cls.get_device_capability() + compute_supported = False + if compute_capability is not None: + major, minor = compute_capability + capability = major * 10 + minor + compute_supported = 80 <= capability < 100 + + # Check if FA packages are available + packages_info = PACKAGES_CHECKER.get_packages_info() + packages_available = packages_info.get("has_flash_attn", False) + + # Both compute capability and packages must be available for FA + flash_attn_supported = compute_supported and packages_available + + if selected_backend is not None: + backend_upper = selected_backend.upper() + if backend_upper == "FLASH_ATTN" and not flash_attn_supported: + if not compute_supported: + logger.warning( + "Flash Attention requires GPU with compute capability >= 8.0 " + "and < 10.0. Falling back to TORCH_SDPA backend." + ) + elif not packages_available: + logger.warning("Flash Attention packages not available. Falling back to TORCH_SDPA backend.") + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + backend = DiffusionAttentionBackendEnum[backend_upper] + logger.info("Using diffusion attention backend '%s'", backend_upper) + return backend.get_path() + + if flash_attn_supported: + logger.info("Defaulting to diffusion attention backend FLASH_ATTN") + return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() + + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + + @classmethod + def supports_torch_inductor(cls) -> bool: + return True + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + if local_rank is None: + return torch.device("cuda") + return torch.device("cuda", local_rank) + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def get_device_count(cls) -> int: + return torch.cuda.device_count() + + @classmethod + def get_device_version(cls) -> str | None: + return torch.version.cuda + + @classmethod + def synchronize(cls) -> None: + torch.cuda.synchronize() + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + free, _ = torch.cuda.mem_get_info(device) + return free + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a0422406389a8881ca9b98c72fcf6d1382a50d78 --- /dev/null +++ b/vllm_omni/platforms/interface.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum + +import torch +from vllm.platforms import Platform + + +class OmniPlatformEnum(Enum): + """Enum for supported Omni platforms.""" + + CUDA = "cuda" + ROCM = "rocm" + NPU = "npu" + XPU = "xpu" + UNSPECIFIED = "unspecified" + + +class OmniPlatform(Platform): + """ + Abstract base class for vllm-omni Platform. + + Inherits from vLLM's Platform and adds Omni-specific interfaces. + This gives OmniPlatform all vLLM Platform capabilities plus + Omni-specific methods. + """ + + _omni_enum: OmniPlatformEnum + + def is_npu(self) -> bool: + return self._omni_enum == OmniPlatformEnum.NPU + + def is_xpu(self) -> bool: + return self._omni_enum == OmniPlatformEnum.XPU + + def is_cuda(self) -> bool: + return self._omni_enum == OmniPlatformEnum.CUDA + + def is_rocm(self) -> bool: + return self._omni_enum == OmniPlatformEnum.ROCM + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + raise NotImplementedError + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + raise NotImplementedError + + @classmethod + def get_default_stage_config_path(cls) -> str: + raise NotImplementedError + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + """Get the diffusion attention backend class path for this platform. + + This method selects the appropriate attention backend for diffusion + models based on platform capabilities and user preferences. + + Args: + selected_backend: User-selected backend name (e.g., "FLASH_ATTN", + "TORCH_SDPA", "SAGE_ATTN"). If None, uses platform default. + head_size: Attention head size. + + Returns: + Fully qualified class path of the selected backend. + """ + raise NotImplementedError + + @classmethod + def supports_torch_inductor(cls) -> bool: + """Check if the platform supports torch.compile with inductor backend.""" + raise NotImplementedError + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + raise NotImplementedError + + @classmethod + def get_device_count(cls) -> int: + raise NotImplementedError + + @classmethod + def get_device_version(cls) -> str | None: + raise NotImplementedError + + @classmethod + def synchronize(cls) -> None: + raise NotImplementedError + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + raise NotImplementedError + + +class UnspecifiedOmniPlatform(OmniPlatform): + _omni_enum = OmniPlatformEnum.UNSPECIFIED + device_type = "" diff --git a/vllm_omni/platforms/npu/__init__.py b/vllm_omni/platforms/npu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb95f7b29c5134034098707124027e777029a868 --- /dev/null +++ b/vllm_omni/platforms/npu/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.npu.platform import NPUOmniPlatform + +__all__ = ["NPUOmniPlatform"] diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2495c3d35a61a35f1a712386c220b34e23e944 --- /dev/null +++ b/vllm_omni/platforms/npu/platform.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger +from vllm_ascend.platform import NPUPlatform + +from 
vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum + +logger = init_logger(__name__) + + +class NPUOmniPlatform(OmniPlatform, NPUPlatform): + """NPU/Ascend implementation of OmniPlatform. + + Inherits all NPU-specific implementations from vllm-ascend's NPUPlatform, + and adds Omni-specific interfaces from OmniPlatform. + """ + + _omni_enum = OmniPlatformEnum.NPU + dist_backend: str = "hccl" + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + return "vllm_omni.platforms.npu.worker.npu_ar_worker.NPUARWorker" + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + return "vllm_omni.platforms.npu.worker.npu_generation_worker.NPUGenerationWorker" + + @classmethod + def get_default_stage_config_path(cls) -> str: + return "vllm_omni/platforms/npu/stage_configs" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + from importlib.util import find_spec + + if selected_backend is not None: + backend_upper = selected_backend.upper() + backend = DiffusionAttentionBackendEnum[backend_upper] + logger.info("Using diffusion attention backend '%s'", backend_upper) + return backend.get_path() + + # Try FLASH_ATTN if mindiesd is available, otherwise fall back to SDPA + if find_spec("mindiesd"): + logger.info("Defaulting to diffusion attention backend FLASH_ATTN") + return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() + + logger.info("Falling back to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + + @classmethod + def supports_torch_inductor(cls) -> bool: + return False + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + if local_rank is None: + return torch.device("npu") + return torch.device("npu", local_rank) + + @classmethod + def get_device_count(cls) -> int: + return torch.npu.device_count() + + @classmethod + def get_device_version(cls) -> str | None: + return None + + @classmethod + def synchronize(cls) -> None: + torch.npu.synchronize() + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + free, _ = torch.npu.mem_get_info(device) + return free + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.npu.get_device_properties(device_id) + return device_props.total_memory diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d021ef218ee3de1a23bfafcc1bf41a13d25c7ce3 --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,97 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. 
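+# Stage 0: thinker   - multimodal understanding and text generation (final text output)
+# Stage 1: talker    - consumes thinker outputs (engine_input_source: [0]) and emits latents
+# Stage 2: code2wav  - consumes talker outputs (engine_input_source: [1]) and emits the audio waveform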
+stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true # haven't supported talker ACL graph on NPU + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f99ed22e8789e4dc13b4262e5c2b921ab2032233 --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml @@ -0,0 +1,96 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 5x A2/A3-64G NPUs. 
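+# Device layout: stage 0 (thinker) runs tensor_parallel_size=4 on NPUs 0-3, stage 1
+# (talker) runs on NPU 4, and stage 2 (code2wav) shares NPU 0 under a small memory
+# budget (gpu_memory_utilization: 0.1), which is what the 5-NPU note above refers to.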
+stage_args: + - stage_id: 0 + runtime: + devices: "0,1,2,3" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 4 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "4" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.2 + enforce_eager: true # haven't supported talker ACL graph on NPU + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d408dbab91e455ebe56dbb60552be3922d9d25e7 --- /dev/null +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -0,0 +1,22 @@ +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + + final_output: true + final_output_type: audio diff --git 
a/vllm_omni/platforms/npu/worker/__init__.py b/vllm_omni/platforms/npu/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf9e632f78c109dff7e199fee999cc4dc6ff320 --- /dev/null +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -0,0 +1,523 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from copy import copy +from typing import Any, NamedTuple + +import numpy as np +import torch +from vllm.config import CUDAGraphMode +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.forward_context import get_forward_context +from vllm.logger import logger +from vllm.sequence import IntermediateTensors +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + ECConnectorOutput, + make_empty_encoder_model_runner_output, +) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState + +# yapf conflicts with isort for this block +# yapf: disable +from vllm_ascend.compilation.acl_graph import update_full_graph_params +from vllm_ascend.ops.rotary_embedding import update_cos_sin +from vllm_ascend.utils import ProfileExecuteDuration + +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner + + +class ExecuteModelState(NamedTuple): + """Ephemeral cached state transferred between execute_model() and + sample_tokens(), after execute_model() returns None.""" + + scheduler_output: SchedulerOutput + logits: torch.Tensor + spec_decode_metadata: SpecDecodeMetadata | None + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + kv_connector_output: KVConnectorOutput | None + attn_metadata: dict[str, Any] + positions: torch.Tensor + ec_connector_output: ECConnectorOutput | None + multimodal_outputs: Any + +class NPUARModelRunner(OmniNPUModelRunner): + """Autoregressive NPU model runner that returns hidden states per request.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + # each model stage has their own hidden size + self.hidden_size = self.model_config.hf_text_config.hidden_size + self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) + # Initialize KV cache manager (preserve vllm_config fallback behavior) + self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) + 
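+    # Note: the inputs_embeds buffer above scales with max_num_tokens x hidden_size,
+    # which is what motivates the Ray pin-memory guard in _make_buffer below.
+    # Rough footprint of inputs_embeds alone, with illustrative numbers (not taken
+    # from any shipped config):
+    # 8192 tokens x 3584 hidden dims x 2 bytes (bf16) is roughly 56 MiB per copy.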
+ def _make_buffer(self, *size, dtype, numpy=True): + # Prevent ray from pinning the buffer due to large size + from vllm_omni.distributed.ray_utils.utils import ( + calculate_total_bytes, + maybe_disable_pin_memory_for_ray, + ) + + total_bytes = calculate_total_bytes(size, dtype) + + # Use the context manager to temporarily disable pinning if needed + with maybe_disable_pin_memory_for_ray(self, total_bytes): + return super()._make_buffer(*size, dtype=dtype, numpy=numpy) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called " + "after execute_model() returns None.") + + # -------------------------------------- Omni-new ------------------------------------------------- + # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) + self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( + finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + kv_caches=self.kv_caches, + block_size=self.cache_config.block_size, + cache_dtype=str(self.cache_config.cache_dtype), + request_id_resolver=self._resolve_global_request_id, + ) + # -------------------------------------- Omni-new ------------------------------------------------- + + with ProfileExecuteDuration().capture_async("prepare input"): + # -------------------------------------- Omni-new ------------------------------------------------- + self._update_states(scheduler_output) + self._decode_and_store_request_payloads(scheduler_output) + # ------------------------------------------------------------------------------------------------ + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output( + scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug( + "skip this step for we receive the data from remote disaggregate prefill node" + ) + # Return empty ModelRunnerOutput if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + (attn_metadata, num_scheduled_tokens_np, num_input_tokens, + num_tokens_across_dp, logits_indices, spec_decode_metadata, + max_query_len) = self._prepare_inputs(scheduler_output) + + (input_ids, inputs_embeds, positions, intermediate_tensors, + model_kwargs, ec_connector_output) = self._preprocess(scheduler_output, + num_input_tokens, + intermediate_tensors) + + # update global cos, sin + update_cos_sin(positions) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + # prevent debugger is None + if self.debugger is not None: + dbg_cfg = getattr(self.debugger, "config", None) + dump_level = str( + getattr(dbg_cfg, "level", + "L1")).upper() if dbg_cfg is not None else "L1" + if dump_level in ("L0", "MIX"): + self.debugger.start(model=self.model) + else: + self.debugger.start() + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) + has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 + aclgraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=has_lora + ) + + if self.ascend_config.enable_async_exponential: + self.sampler.do_async_exponential( + b_s=logits_indices.shape[0], + head_dim=self.model_config.get_vocab_size(), + generators=self.input_batch.sampling_metadata.generators) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output. 
+ total_num_scheduled_tokens, + model_instance=self.model): + self.maybe_setup_kv_connector(scheduler_output) + + hidden_states = self._generate_process_reqs_hidden_states( + num_input_tokens, input_ids, positions, + intermediate_tensors, inputs_embeds, model_kwargs) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) + + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states + + kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) + finished_sending = None + finished_recving = None + with ProfileExecuteDuration().capture_async("post process"): + # -------------------------------------- Omni-new ------------------------------------------------- + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + + if multimodal_outputs is not None: + keys_or_type = ( + list(multimodal_outputs.keys()) + if isinstance(multimodal_outputs, dict) + else type(multimodal_outputs) + ) + logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}") + else: + logger.debug("[AR] execute_model: multimodal_outputs is None") + # -------------------------------------- Omni-new ------------------------------------------------- + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. 
+ if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output + if self.debugger is not None: + self.debugger.stop() + self.debugger.step() + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + pool_output = self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, kv_connector_output) + if self.debugger is not None: + self.debugger.stop() + self.debugger.step() + return pool_output + sample_hidden_states = hidden_states[logits_indices] + # -------------------------------------- Omni-new ------------------------------------------------- + # Try with sampling_metadata first; fall back to without for models that don't support it + try: + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) + except TypeError: + logits = self.model.compute_logits(sample_hidden_states) + # -------------------------------------- Omni-new ------------------------------------------------- + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, # Omni-new + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + # -------------------------------------- Omni-new ------------------------------------------------- + kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None) + self.kv_extracted_req_ids = None + # -------------------------------------- Omni-new ------------------------------------------------- + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # noqa + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. 
+ if grammar_output is not None: + # here we are different from gpu_model_runner, + # the apply_grammar_bitmask uses torch.compile to optimize this,ascend does not support it now + logits_dtype = logits.dtype + logits = logits.to("cpu").float() + apply_grammar_bitmask(scheduler_output, grammar_output, + self.input_batch, logits) + logits = logits.to(self.device).to(logits_dtype) + + with ProfileExecuteDuration().capture_async("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + def propose_draft_token_ids(sampled_token_ids): + assert self.spec_decode_common_attn_metadata is not None + self._draft_token_ids = self.propose_draft_token_ids( + sampled_token_ids, + self.input_batch.sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + scheduler_output.total_num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + sample_hidden_states + ) + self._copy_draft_token_ids_to_cpu(scheduler_output) + + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + + with ProfileExecuteDuration().capture_async("Draft"): + if self.speculative_config: + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.use_eagle() and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch_for_eagle: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampler_output.sampled_token_ids) + if self.speculative_config and not use_padded_batch_for_eagle: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
+ propose_draft_token_ids(valid_sampled_token_ids) + + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + # -------------------------------------- Omni-new ------------------------------------------------- + hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() + num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) + if num_scheduled_tokens_np is None: + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + + self._process_additional_information_updates( + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + ) + + pooler_output: list[dict[str, object]] = [] + for rid in req_ids_output_copy: + idx = req_id_to_index_output_copy[rid] + start = int(self.query_start_loc.cpu[idx]) + sched = int(num_scheduled_tokens_np[idx]) + end = start + sched + hidden_slice = hidden_states_cpu[start:end] + payload: dict[str, object] = {"hidden": hidden_slice} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + mm_payload: dict[str, object] = {} + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_payload[k] = v.detach().to("cpu")[start:end].contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: + sub_dict[str(sk)] = sv.detach().to("cpu")[start:end].contiguous() + if sub_dict: + mm_payload[k] = sub_dict + elif isinstance(v, list): + element = v[0] + if isinstance(element, torch.Tensor): + element = element.detach().to("cpu").contiguous() + mm_payload[k] = element + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + if mm_payload: + payload.update(mm_payload) + pooler_output.append(payload) + + model_runner_output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + kv_connector_output=kv_connector_output, + ) + model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids + # -------------------------------------- Omni-new ------------------------------------------------- + + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [ + f"[{tag}]:{duration:.2f}ms" + for tag, duration in durations.items() + ] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + logger.info("Profile execute duration [%s]:%s", captured_name, + " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + if not self.use_async_scheduling: + if self.debugger is not None: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() + return model_runner_output + + if self.debugger is not None: + assert self.debugger is not None + self.debugger.stop() + self.debugger.step() + return AsyncGPUModelRunnerOutput( + model_runner_output=model_runner_output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, 
+ ) + + def _generate_process_reqs_hidden_states(self, num_input_tokens, + input_ids, positions, + intermediate_tensors, + inputs_embeds, model_kwargs): + assert self.model is not None + hidden_states = self._model_forward(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs) + + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ + and not self.use_sparse: + if self.vllm_config.model_config.use_mla: + if self.pcp_size * self.dcp_size > 1: + update_full_graph_params(self.attn_backend, self.update_stream, forward_context, + num_input_tokens, self.vllm_config, + self.vllm_config.speculative_config) + + if get_forward_context().sp_enabled and not isinstance( + hidden_states, IntermediateTensors): + hidden_states = self._all_gather_hidden_states_and_aux( + hidden_states) + return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( + hidden_states) + + def _resolve_global_request_id(self, req_id: str) -> str: + """Resolve global request ID from request state.""" + req_state = self.requests.get(req_id) + if not req_state: + return req_id + + add_info = getattr(req_state, "additional_information_cpu", {}) or {} + global_id = add_info.get("global_request_id") + if global_id: + if isinstance(global_id, list) and global_id: + global_id = global_id[0] + if isinstance(global_id, bytes): + return global_id.decode("utf-8") + return str(global_id) + return req_id diff --git a/vllm_omni/platforms/npu/worker/npu_ar_worker.py b/vllm_omni/platforms/npu/worker/npu_ar_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa3e59edc4822ef227b5c7c856f12b99d636733 --- /dev/null +++ b/vllm_omni/platforms/npu/worker/npu_ar_worker.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.worker.workspace import init_workspace_manager +from vllm_ascend.worker.worker import NPUWorker + +from vllm_omni.platforms.npu.worker.npu_ar_model_runner import NPUARModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + + +class NPUARWorker(OmniWorkerMixin, NPUWorker): + """NPU AR worker for thinker/talker stages in Omni model.""" + + def init_device(self): + self.device = self._init_device() + num_ubatches = 1 + init_workspace_manager(self.device, num_ubatches) + + self.model_runner = NPUARModelRunner(self.vllm_config, self.device) diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..e8559bb463ca18e45668306129479ced1b2369cc --- /dev/null +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -0,0 +1,544 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import gc +import math +from copy import copy + +import numpy as np +import torch +from vllm.config import CUDAGraphMode +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import has_kv_transfer_group +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import logger +from vllm.sequence import IntermediateTensors +from vllm.utils.math_utils import cdiv +from 
vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, make_empty_encoder_model_runner_output +from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput +from vllm_ascend.ascend_forward_context import MoECommType, get_mc2_tokens_capacity, set_ascend_forward_context +from vllm_ascend.ops.rotary_embedding import update_cos_sin +from vllm_ascend.platform import NPUPlatform +from vllm_ascend.utils import ProfileExecuteDuration, enable_sp, lmhead_tp_enable + +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.platforms.npu.worker.npu_ar_model_runner import ExecuteModelState +from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner + + +class NPUGenerationModelRunner(OmniNPUModelRunner): + """Generation model runner for vLLM-omni on NPU (non-autoregressive).""" + + def _update_request_states(self, scheduler_output: SchedulerOutput): + cached_reqs = scheduler_output.scheduled_cached_reqs + for _, req_id in enumerate(cached_reqs.req_ids): + req_state = self.requests.get(req_id) + assert req_state is not None + req_state.prompt_token_ids = cached_reqs.prompt_token_ids.get(req_id) + self.input_batch.remove_request(req_id) + # update the request state in self.input_batch + self.input_batch.add_request(req_state) + self._init_mrope_positions(req_state) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> OmniModelRunnerOutput | IntermediateTensors: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") + + with ProfileExecuteDuration().capture_async("prepare input"): + # -------------------------------------- Omni-new ------------------------------------------------- + if self.model_config.async_chunk: + self._update_request_states(scheduler_output) + # -------------------------------------- Omni-new ------------------------------------------------- + self._update_states(scheduler_output) + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug("skip this step for we receive the data from remote disaggregate prefill node") + # Return empty ModelRunnerOutput if there's no work to do. 
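+ # If there is no KV connector, the shared empty output is returned as-is;
+ # otherwise kv_connector_no_forward() below lets in-flight disaggregated
+ # KV transfers make progress even though the model is not run this step.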
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + ( + attn_metadata, + num_scheduled_tokens_np, + num_input_tokens, + num_tokens_across_dp, + logits_indices, + spec_decode_metadata, + max_query_len, + ) = self._prepare_inputs(scheduler_output) + + (input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, ec_connector_output) = ( + self._preprocess(scheduler_output, num_input_tokens, intermediate_tensors) + ) + + # update global cos, sin + update_cos_sin(positions) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + # prevent debugger is None + if self.debugger is not None: + dbg_cfg = getattr(self.debugger, "config", None) + dump_level = str(getattr(dbg_cfg, "level", "L1")).upper() if dbg_cfg is not None else "L1" + if dump_level in ("L0", "MIX"): + self.debugger.start(model=self.model) + else: + self.debugger.start() + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 + aclgraph_runtime_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + + if self.ascend_config.enable_async_exponential: + self.sampler.do_async_exponential( + b_s=logits_indices.shape[0], + head_dim=self.model_config.get_vocab_size(), + generators=self.input_batch.sampling_metadata.generators, + ) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + model_instance=self.model, + ): + self.maybe_setup_kv_connector(scheduler_output) + # -------------------------------------- Omni-new ------------------------------------------------- + outputs = self._run_generation_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + model_kwargs=model_kwargs, + logits_indices=logits_indices, + ) + # -------------------------------------- Omni-new ------------------------------------------------- + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer(scheduler_output) + + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + outputs, aux_hidden_states = outputs + + kv_connector_output = KVConnectorOutput(finished_sending=finished_sending, finished_recving=finished_recving) + finished_sending = None + finished_recving = None + + _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + # Apply structured output bitmasks if present + self.execute_model_state = ExecuteModelState( + scheduler_output, + None, + spec_decode_metadata, + outputs, + None, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = 
self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # noqa + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + attn_metadata, + positions, + ec_connector_output, + multimodal_outputs, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # -------------------------------------- Omni-new ------------------------------------------------- + pooler_output: list[object] = [] + if isinstance(multimodal_outputs, torch.Tensor): + assert multimodal_outputs.shape[0] == 1, ( + "model should return a single tensor, to return multiple tensors, use a dict" + ) + assert multimodal_outputs.shape[0] == self.input_batch.num_reqs + for i in range(self.input_batch.num_reqs): + pooler_output.append({"model_outputs": multimodal_outputs[i].detach().to("cpu").contiguous()}) + elif isinstance(multimodal_outputs, list): + assert len(multimodal_outputs) == 1, ( + "model should return a single list, to return multiple lists, use a dict" + ) + for out in multimodal_outputs: + pooler_output.append( + {"model_outputs": out.detach().to("cpu").contiguous() if out is not None else None} + ) + elif isinstance(multimodal_outputs, dict): + mm_payload = {} + for key, out in multimodal_outputs.items(): + if out is not None and isinstance(out, torch.Tensor): + mm_payload[key] = out.detach().to("cpu").contiguous() + pooler_output.append(mm_payload) + else: + raise RuntimeError("Unsupported diffusion output type") + output = OmniModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + num_nans_in_logits={}, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, + ) + # -------------------------------------- Omni-new ------------------------------------------------- + if not self.use_async_scheduling: + return output + return AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=torch.tensor([], device=self.device), + logprobs_tensors=None, + invalid_req_indices=[], + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + + def _run_generation_model( + self, + *, + input_ids: torch.Tensor | None, + positions: torch.Tensor | None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None, + model_kwargs: dict, + logits_indices: torch.Tensor, + ) -> torch.Tensor | list[torch.Tensor]: + """Run generation from codec codes to waveforms. 
+ + Args: + scheduler_output: Contains codec codes in input_ids or additional info + intermediate_tensors: PP intermediate tensors if applicable + + Returns: + Audio waveforms: [batch, 1, waveform_len] or list of tensors + """ + # Keep inputs identical to AR runner + kwargs = dict( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + ) + + if hasattr(self.model, "forward"): + return self._model_forward(**kwargs) + + raise RuntimeError( + "The loaded model does not expose diffusion interfaces 'sample', " + "'forward', or 'diffuse'. Please implement one of them or adapt the runner." + ) + + @torch.inference_mode() + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + logger.warning("Dummy sampler run is not implemented for generation model") + return None + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + is_profile: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + ) -> torch.Tensor: + # only support eager mode and piecewise graph now + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { + CUDAGraphMode.NONE, + CUDAGraphMode.PIECEWISE, + CUDAGraphMode.FULL, + } + # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. + # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size. + if self.use_aclgraph and enable_sp(self.vllm_config): + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_tokens = math.ceil(num_tokens / tp_size) * tp_size + + # Force dummy run on prefill stage when this node is deemed as kv producer. + if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True + + has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp( + batch_descriptor.num_tokens, with_prefill + ) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
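+ # Illustrative example (numbers are hypothetical): with uniform_decode=True,
+ # num_tokens=7 and max_query_len=2, cdiv(7, 2) gives num_reqs=4 and the
+ # schedule becomes [2, 2, 2, 1] -- the last request absorbs the remainder.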
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + if not is_profile and self.dynamic_eplb: + self.eplb_updator.forward_before() + + if num_tokens != batch_descriptor.num_tokens: + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + + num_tokens_padded = batch_descriptor.num_tokens + num_reqs_padded = batch_descriptor.num_reqs if batch_descriptor.num_reqs is not None else num_reqs + if num_tokens_across_dp is not None and num_tokens_padded != num_tokens: + # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher + num_tokens_across_dp[:] = num_tokens_padded + num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) + + # filter out the valid batch descriptor + if cudagraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for aclgraph capture + if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode: + raise ValueError( + f"Aclgraph runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}." + ) + else: + cudagraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. 
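+ # The dummy attention metadata below is built for the padded shapes
+ # (num_reqs_padded / num_tokens_padded) so that aclgraph capture and
+ # warm-up observe the same tensor shapes as real execution.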
+ attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs_padded, + num_tokens=num_tokens_padded, + max_query_len=max_query_len, + aclgraph_runtime_mode=cudagraph_runtime_mode, + force_attention=force_attention, + is_graph_capturing=is_graph_capturing, + num_scheduled_tokens=num_scheduled_tokens, + ) + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, num_sampled_tokens): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_padded <= self.max_num_tokens + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + else: + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] + else: + positions = self.positions.gpu[:num_tokens_padded] + + # update global cos, sin + update_cos_sin(positions) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + # When PP and flashcomm1 are enabled, + # during dummy_run the estimated space should divide num_tokens by tp_size; + # otherwise, on non-first PP ranks it would effectively perform an extra all-gather, + # leading to incorrect memory estimation and potentially causing OOM. + actual_tokens = num_tokens + if enable_sp(): + tp_size = get_tensor_model_parallel_world_size() + actual_tokens = num_tokens // tp_size + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=actual_tokens, dtype=self.dtype, device=self.device + ) + intermediate_tensors = IntermediateTensors( + {k: v[:num_tokens_padded] for k, v in self.intermediate_tensors.items()} + ) + + need_dummy_logits = not is_profile and lmhead_tp_enable() + max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + if not need_dummy_logits: + return None + return self.model.compute_logits(hidden_states[dummy_indices]) + + def dummy_drafter_compute_logits(hidden_states): + if not need_dummy_logits or self.drafter is None: + return + if hasattr(self.drafter, "model") and hasattr(self.drafter.model, "compute_logits"): + return self.drafter.model.compute_logits(hidden_states[dummy_indices]) + + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + in_profile_run=is_profile, + num_actual_tokens=0, + aclgraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + model_instance=self.model, + ): + hidden_states = self._generate_dummy_run_hidden_states( + input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds + ) + dummy_compute_logits(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens_padded, + with_prefill=with_prefill, + num_reqs=num_reqs_padded, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + dummy_compute_logits=dummy_drafter_compute_logits, + in_graph_capturing=not force_attention, + is_profile=is_profile, + ) + if is_profile and self.dynamic_eplb: + 
self.model.clear_all_moe_loads() + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + # -------------------------------------- Omni-new ------------------------------------------------- + hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + # ------------------------------------------------------------------------------------------------- + return hidden_states + + def profile_run(self) -> None: + # Trigger compilation for general shape. + with self.set_in_profile_run(): + hidden_states = self._dummy_run( + self.max_num_tokens // self.pcp_size if self.pcp_size > 1 else self.max_num_tokens, with_prefill=True + ) + # MC2 will consume additional NPU memory. + # Therefore, we need to run the MC2 path once here to complete its initialization, + # allowing vLLM to correctly estimate the maximum memory required. + mc2_tokens_capacity = get_mc2_tokens_capacity() + if ( + self.max_num_tokens > mc2_tokens_capacity + and self._select_moe_comm_method(mc2_tokens_capacity) == MoECommType.MC2 + ): + self._dummy_run(mc2_tokens_capacity, with_prefill=True) + + output = None + + NPUPlatform.synchronize() + del hidden_states, output + self.encoder_cache.clear() + gc.collect() diff --git a/vllm_omni/platforms/npu/worker/npu_generation_worker.py b/vllm_omni/platforms/npu/worker/npu_generation_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..edbaa9f64a5360646446ec83c3e6edd72250d09a --- /dev/null +++ b/vllm_omni/platforms/npu/worker/npu_generation_worker.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.worker.workspace import init_workspace_manager +from vllm_ascend.worker.worker import NPUWorker + +from vllm_omni.platforms.npu.worker.npu_generation_model_runner import NPUGenerationModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + + +class NPUGenerationWorker(OmniWorkerMixin, NPUWorker): + """NPU generation worker for code2wav stage in Omni model.""" + + def init_device(self): + self.device = self._init_device() + num_ubatches = 1 + init_workspace_manager(self.device, num_ubatches) + + self.model_runner = NPUGenerationModelRunner(self.vllm_config, self.device) diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a376f95fa4fa9994c071d0754a2b5d10fa0fa2 --- /dev/null +++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py @@ -0,0 +1,1056 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +import torch +from vllm.config import CUDAGraphMode +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.interfaces import supports_mrope +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors +from vllm.utils.math_utils import cdiv +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.compilation.acl_graph import 
ACLGraphWrapper +from vllm_ascend.ops.rotary_embedding import update_cos_sin +from vllm_ascend.utils import enable_sp, lmhead_tp_enable +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +class OmniNPUModelRunner(NPUModelRunner): + """ + Base class for NPU model runners with multimodality support. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._omni_per_req_additional_information: dict[str, dict] | None = None + self._omni_num_scheduled_tokens_np: np.ndarray | None = None + self._omni_last_model_output: object | None = None + + def load_model(self, *args, **kwargs) -> None: + super().load_model(*args, **kwargs) + # TODO move this model specific logic to a separate class + if hasattr(self.model, "talker_mtp") and self.model.talker is not None: + self.talker_mtp = self.model.talker_mtp + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + if cudagraph_mode.has_full_cudagraphs(): + self.talker_mtp = ACLGraphWrapper( + self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size + max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) + self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) + self.talker_mtp_inputs_embeds = self._make_buffer( + max_batch_size, hidden_size, dtype=self.dtype, numpy=False + ) + self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + # Check for use_audio_in_video + use_audio_in_video_value = mm_input.get("use_audio_in_video") + if use_audio_in_video_value is not None: + use_audio_in_video = bool(use_audio_in_video_value.item()) + + if supports_mrope(self.get_model()): + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + mm_features=req_state.mm_features, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + 
audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. + unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. 
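+ # Each new request gets a CachedRequestState (sampling params, optional
+ # seeded generator, block ids). The Omni-specific blocks that follow then
+ # decode any serialized prompt_embeds / additional_information payloads to
+ # CPU tensors and stash them on the request state for later use in
+ # _preprocess and postprocess.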
+ for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if sampling_params and sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + self.requests[req_id] = req_state + + # -------------------------------------- Omni-new ------------------------------------------------- + # If prompt embeddings are provided, decode and attach to inter_data + try: + if getattr(new_req_data, "prompt_embeds", None) is not None: + payload = new_req_data.prompt_embeds + dtype = getattr(np, payload.dtype) + arr = np.frombuffer(payload.data, dtype=dtype) + arr = arr.reshape(payload.shape) + pe_cpu = torch.from_numpy(arr) + # Store temporarily on CPU; later moved to device in builder + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # Also replace payload with Tensor for user visibility in + # scheduler_output + try: + new_req_data.prompt_embeds = pe_cpu # type: ignore[assignment] + except Exception: + pass + except Exception as e: + logger.error(f"Error decoding prompt embeds: {e}") + # Decode additional_information payloads (dictionary) + try: + if getattr(new_req_data, "additional_information", None) is not None: + payload_info = new_req_data.additional_information + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + else: + from vllm_omni.engine import AdditionalInformationPayload + + if isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr.copy()) + else: + info_dict[k] = entry.list_data + if info_dict: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding additional information: {e}") + pass + # -------------------------------------- Omni-new ------------------------------------------------- + + if sampling_params and sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = ( + self.input_batch.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(req_state) + + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._init_xdrope_positions(req_state) + + reqs_to_add.append(self.requests[req_id]) + + # Update the 
states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. + valid_sampled_token_count = self._get_valid_sampled_token_count() + + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_id in req_data.resumed_req_ids + num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + + if req_state.prev_num_draft_len and self.use_async_scheduling: + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step does't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state.num_tokens + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens + self.input_batch.num_tokens_no_spec[req_index] = end_idx + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + else: + assert req_index is None + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. 
+ req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + + reqs_to_add.append(req_state) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens + if new_block_ids is not None: + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() + + @torch.inference_mode() + def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput) -> dict: + if ( + hasattr(self.model, "have_multimodal_outputs") + and self.model.have_multimodal_outputs + and isinstance(hidden_states, OmniOutput) + ): + text_hidden_states = hidden_states.text_hidden_states + multimodal_outputs = hidden_states.multimodal_outputs + + elif isinstance(hidden_states, torch.Tensor): + text_hidden_states = hidden_states + multimodal_outputs = {} + elif isinstance(hidden_states, list) or isinstance(hidden_states, tuple): + text_hidden_states = hidden_states[0] + multimodal_outputs = {} + else: + raise ValueError(f"Invalid hidden states type: {type(hidden_states)}") + return text_hidden_states, multimodal_outputs + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + is_profile: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + ) -> torch.Tensor: + # only support eager mode and piecewise graph now + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { + CUDAGraphMode.NONE, + CUDAGraphMode.PIECEWISE, + CUDAGraphMode.FULL, + } + # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. 
+ # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size. + if self.use_aclgraph and enable_sp(self.vllm_config): + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_tokens = math.ceil(num_tokens / tp_size) * tp_size + + # Force dummy run on prefill stage when this node is deemed as kv producer. + if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True + + has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp( + batch_descriptor.num_tokens, with_prefill + ) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
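+ # Illustrative example (numbers are hypothetical): with with_prefill=False,
+ # decode_token_per_req=4 and num_tokens=10, num_reqs becomes 3, the split
+ # starts as [3, 3, 3] and the remainder 10 % 3 = 1 is added to the last
+ # request, giving [3, 3, 4].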
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + if not is_profile and self.dynamic_eplb: + self.eplb_updator.forward_before() + + if num_tokens != batch_descriptor.num_tokens: + _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora + ) + + num_tokens_padded = batch_descriptor.num_tokens + num_reqs_padded = batch_descriptor.num_reqs if batch_descriptor.num_reqs is not None else num_reqs + if num_tokens_across_dp is not None and num_tokens_padded != num_tokens: + # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher + num_tokens_across_dp[:] = num_tokens_padded + num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) + + # filter out the valid batch descriptor + if cudagraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for aclgraph capture + if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode: + raise ValueError( + f"Aclgraph runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}." + ) + else: + cudagraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. 
+ attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs_padded, + num_tokens=num_tokens_padded, + max_query_len=max_query_len, + aclgraph_runtime_mode=cudagraph_runtime_mode, + force_attention=force_attention, + is_graph_capturing=is_graph_capturing, + num_scheduled_tokens=num_scheduled_tokens, + ) + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, num_sampled_tokens): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_padded <= self.max_num_tokens + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + else: + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] + else: + positions = self.positions.gpu[:num_tokens_padded] + + # update global cos, sin + update_cos_sin(positions) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + # When PP and flashcomm1 are enabled, + # during dummy_run the estimated space should divide num_tokens by tp_size; + # otherwise, on non-first PP ranks it would effectively perform an extra all-gather, + # leading to incorrect memory estimation and potentially causing OOM. + actual_tokens = num_tokens + if enable_sp(): + tp_size = get_tensor_model_parallel_world_size() + actual_tokens = num_tokens // tp_size + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=actual_tokens, dtype=self.dtype, device=self.device + ) + intermediate_tensors = IntermediateTensors( + {k: v[:num_tokens_padded] for k, v in self.intermediate_tensors.items()} + ) + + need_dummy_logits = not is_profile and lmhead_tp_enable() + max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len + dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + if not need_dummy_logits: + return None + return self.model.compute_logits(hidden_states[dummy_indices]) + + def dummy_drafter_compute_logits(hidden_states): + if not need_dummy_logits or self.drafter is None: + return + if hasattr(self.drafter, "model") and hasattr(self.drafter.model, "compute_logits"): + return self.drafter.model.compute_logits(hidden_states[dummy_indices]) + + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + in_profile_run=is_profile, + num_actual_tokens=0, + aclgraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + model_instance=self.model, + ): + if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + num_tokens_padded_talker_mtp = num_tokens_padded + if num_tokens_padded_talker_mtp == self.max_num_tokens: + num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0] + hidden_states = self.talker_mtp( + self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp], + self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp], + self.text_step.gpu[:num_tokens_padded_talker_mtp], + ) + self.compilation_config.cache_dir = None + hidden_states = 
self._generate_dummy_run_hidden_states( + input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds + ) + dummy_compute_logits(hidden_states) + + hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens_padded, + with_prefill=with_prefill, + num_reqs=num_reqs_padded, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + dummy_compute_logits=dummy_drafter_compute_logits, + in_graph_capturing=not force_attention, + is_profile=is_profile, + ) + if is_profile and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + return hidden_states, hidden_states + + def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None: + """Decode per-request prompt_embeds and additional_information for newly + scheduled requests and store them to CPU in the request state. + This version avoids hard dependency on payload classes by duck-typing.""" + try: + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if not new_reqs: + return + for nr in new_reqs: + req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) + if req_id is None: + continue + # prompt_embeds + payload_pe = getattr(nr, "prompt_embeds", None) + pe_cpu = None + if payload_pe is not None: + if isinstance(payload_pe, torch.Tensor): + pe_cpu = payload_pe.detach().to("cpu").contiguous() + else: + # Try duck-typing a payload with data/shape/dtype + data = getattr(payload_pe, "data", None) + shape = getattr(payload_pe, "shape", None) + if data is not None and shape is not None: + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(data, dtype=dt) + arr = arr.reshape(shape) + pe_cpu = torch.from_numpy(arr.copy()) + if pe_cpu is not None and req_id in self.requests: + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + else: + # Try duck-typing a payload with entries, each entry may have + # tensor_data/tensor_dtype/tensor_shape or list_data + entries = getattr(payload_info, "entries", None) + if isinstance(entries, dict): + for k, entry in entries.items(): + tensor_data = getattr(entry, "tensor_data", None) + if tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(tensor_data, dtype=dt) + arr = arr.reshape(getattr(entry, "tensor_shape", ())) + info_dict[k] = torch.from_numpy(arr.copy()) + else: + info_dict[k] = getattr(entry, "list_data", None) + if info_dict and req_id in self.requests: + setattr(self.requests[req_id], "additional_information_cpu", info_dict) + except Exception as e: + logger.error(f"Error decoding prompt_embeds / additional_information: {e}") + + def _gather_runtime_additional_information(self) -> list[dict]: + """Gather per-request additional_information stored in request state in batch order.""" + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + if info and isinstance(info, dict): + per_req_runtime_info.append(info) + if 
"thinker_reply_part_per_request" in info: + q = info["thinker_reply_part_per_request"] + if hasattr(q, "shape"): + logger.debug(f"[OMNI] req={req_id} has thinker_reply_part_per_request queue shape: {q.shape}") + else: + per_req_runtime_info.append({}) + return per_req_runtime_info + + def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[int, int]]: + """Compute (start, end) token spans for each request within the flattened step sequence.""" + req_token_spans: list[tuple[int, int]] = [] + for req_index in range(len(self.input_batch.req_ids)): + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + return req_token_spans + + def _build_model_kwargs_extra(self) -> dict: + """Build extra keyword arguments passed to the model for this step, including: + - runtime_additional_information: per-request additional information stored in request state + """ + model_kwargs_extra: dict[str, object] = {} + try: + model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information() + except Exception as e: + logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}") + import traceback + + traceback.print_exc() + return model_kwargs_extra + + def _process_additional_information_updates( + self, + hidden_states: torch.Tensor, + multimodal_outputs: object, + num_scheduled_tokens_np: np.ndarray, + scheduler_output: "SchedulerOutput", + ) -> None: + """Process model-provided per-request additional_information updates and merge into request state.""" + try: + # execute the custom postprocess function + # TODO(Peiqi): do we have a more elegant way to do this? + if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: + for req_index, req_id in enumerate(self.input_batch.req_ids): + if self.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + # only consider to store data into update dict. + hidden_states_slice = hidden_states[s:e] + update_dict = self.model.postprocess(hidden_states_slice, **req_infos) + self._merge_additional_information_update(req_id, update_dict) + except Exception as e: + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} " + f"additional information update: {e}, with the multimodal_outputs " + f"as {multimodal_outputs}" + ) + import traceback + + traceback.print_exc() + + def _collect_additional_information_for_prefill( + self, + num_scheduled_tokens_np: np.ndarray, + ) -> dict[str, dict]: + """Overlay per-request prompt_embeds for the prefill portion and collect + additional_information slices for this step. 
Returns a map req_id -> dict.""" + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + if overlay_len > 0 and pe_cpu is not None: + src = pe_cpu[num_computed_tokens : num_computed_tokens + overlay_len].to( + dtype=self.dtype, device=self.device, non_blocking=True + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + + def _get_additional_information(self, scheduler_output: "SchedulerOutput", req_id: str) -> dict: + req_infos = None + req_state = self.requests.get(req_id) + additional_information_cpu = getattr(req_state, "additional_information_cpu", None) + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id == req_id: + payload_info = getattr(new_req, "additional_information", None) + if payload_info is not None: + return payload_info + + if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"): + cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {}) + if isinstance(cached_infos, dict) and req_id in cached_infos: + req_infos = cached_infos[req_id] + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None or req_infos.get("last_talker_hidden", None) is None: + if req_infos is None: + additional_information_cpu.pop("thinker_embeddings", None) + req_infos = additional_information_cpu + else: + req_infos["last_talker_hidden"] = additional_information_cpu.get("last_talker_hidden", None) + req_infos["num_processed_thinker_tokens"] = additional_information_cpu.get( + "num_processed_thinker_tokens", 0 + ) + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None: + logger.warning(f"No additional_information found for req_id: {req_id}") + + return req_infos + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + num_input_tokens: int, + intermediate_tensors: IntermediateTensors | None = None, + ): + """Align with v0.14.0 preprocess and omni's additional information handling.""" + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + is_first_rank = get_pp_group().is_first_rank + is_encoder_decoder = self.model_config.is_encoder_decoder + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + ec_connector_output = None + + if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder: + # Run the multimodal encoder if any. + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) + + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. 
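+ # embed_input_ids() merges the gathered multimodal embeddings into the
+ # token embeddings at the positions flagged by is_mm_embed; the merged
+ # result is then copied into the persistent inputs_embeds buffer below.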
+ inputs_embeds_scheduled = self.model.embed_input_ids( + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) + + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) + model_kwargs = { + **self._init_model_kwargs(), + **self._extract_mm_kwargs(scheduler_output), + } + elif self.enable_prompt_embeds and is_first_rank: + # Get the input embeddings for the tokens that are not input embeds, + # then put them into the appropriate positions. + # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the CUDA graph all the time. The v0 + # engine avoids this by "double compiling" the CUDA graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the CUDA graph will be more performant (like in the else case + # below). + token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens].nonzero(as_tuple=False).squeeze(1) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids.gpu[token_ids_idx] + tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs() + input_ids = self.input_ids.gpu[:num_input_tokens] + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids.gpu[:num_input_tokens] + inputs_embeds = None + model_kwargs = self._init_model_kwargs() + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_input_tokens] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_input_tokens] + else: + positions = self.positions.gpu[:num_input_tokens] + + if is_first_rank: + intermediate_tensors = None + else: + assert intermediate_tensors is not None + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True + ) + + if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + # Run the encoder, just like we do with other multimodal inputs. + # For an encoder-decoder model, our processing here is a bit + # simpler, because the outputs are just passed to the decoder. + # We are not doing any prompt replacement. We also will only + # ever have a single encoder input. + encoder_outputs = self._execute_mm_encoder(scheduler_output) + model_kwargs.update({"encoder_outputs": encoder_outputs}) + + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + self._omni_num_scheduled_tokens_np = num_scheduled_tokens_np + + # Note: only prefill need collect additional_information for now. + # Decode don't need per_req_additional_information anymore. 
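+ # Prompt-embedding overlay is applied only to the prefill portion: for each
+ # request, overlay_len = min(scheduled_tokens, prompt_len - num_computed)
+ # tokens of prompt_embeds_cpu are copied into inputs_embeds starting at
+ # query_start_loc[req_index]; pure decode steps fall through untouched.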
+ if inputs_embeds is not None: + # Prefill: overlay prompt_embeds and collect additional_information + self._collect_additional_information_for_prefill(num_scheduled_tokens_np) + + if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: + # Overlay custom prompt_embeds per request for the prompt portion; + # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] + for req_index, req_id in enumerate(self.input_batch.req_ids): + # Try to get additional_information from multiple sources + if self.vllm_config.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + span_len = int(e) - int(s) + + # call the custom process function + req_input_ids, req_embeds, update_dict = self.model.preprocess( + input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + ) + if hasattr(self.model, "talker_mtp") and span_len == 1: + last_talker_hidden, text_step = update_dict.pop("mtp_inputs") + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) + + # TODO(Peiqi): the merge stage could move out from the critical path + self._merge_additional_information_update(req_id, update_dict) + + # update the inputs_embeds and input_ids + seg_len = min(span_len, req_embeds.shape[0]) + inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] + if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: + input_ids[s : s + seg_len] = req_input_ids + + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) + + return ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) + + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] + last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] + text_step = self.text_step.gpu[:decode_batch_size] + with set_ascend_forward_context( + None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = self.input_batch.req_ids.index(req_id) + 
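+            # Map each MTP decode request back to its slot in the persistent batch so
+            # the newly predicted embedding can be written at that request's query
+            # offset and its code_predictor_codes cached on the request state.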
start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + self._merge_additional_information_update(req_id, update_dict) + + def _model_forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], + ): + """Inject omni-specific kwargs into forward and cache model output""" + model_kwargs_extra = self._build_model_kwargs_extra() + + runtime_info = model_kwargs_extra.get("runtime_additional_information", []) + if runtime_info: + for i, info in enumerate(runtime_info): + if info: + logger.debug(f"[OMNI] req[{i}] runtime_additional_information keys: {list(info.keys())}") + + model_output = super()._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + **model_kwargs_extra, + ) + if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"): + model_output = self.model.make_omni_output(model_output, **model_kwargs_extra) + # Cache model output so later sample_tokens can consume multimodal results. + self._omni_last_model_output = model_output + return model_output + + def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + req_state = self.requests.get(req_id) + if req_state is None: + return + existing = getattr(req_state, "additional_information_cpu", {}) + if not isinstance(existing, dict): + existing = {} + merged = dict(existing) + for k, v in upd.items(): + if isinstance(v, torch.Tensor): + merged[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + merged[k] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v + ] + else: + merged[k] = v + setattr(req_state, "additional_information_cpu", merged) diff --git a/vllm_omni/platforms/rocm/__init__.py b/vllm_omni/platforms/rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..999dc5a086d6d40e7d4106e8d6dcf051e79de8c7 --- /dev/null +++ b/vllm_omni/platforms/rocm/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.rocm.platform import RocmOmniPlatform + +__all__ = ["RocmOmniPlatform"] diff --git a/vllm_omni/platforms/rocm/platform.py b/vllm_omni/platforms/rocm/platform.py new file mode 100644 index 0000000000000000000000000000000000000000..14534f3a13cfd1045c6083a54edb7b7a82f3e8c5 --- /dev/null +++ b/vllm_omni/platforms/rocm/platform.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger +from vllm.platforms.rocm import RocmPlatform + +from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum + +logger = init_logger(__name__) + + +class RocmOmniPlatform(OmniPlatform, RocmPlatform): + """ROCm/AMD GPU implementation of OmniPlatform. + + Inherits all ROCm-specific implementations from vLLM's RocmPlatform, + and adds Omni-specific interfaces from OmniPlatform. 
+ """ + + _omni_enum = OmniPlatformEnum.ROCM + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + return "vllm_omni.worker.gpu_ar_worker.GPUARWorker" + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + return "vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + from vllm._aiter_ops import is_aiter_found_and_supported + + # Check if aiter is available for Flash Attention support + # aiter currently only is supported on gfx942 and gfx950 + # https://github.com/vllm-project/vllm/blob/main/vllm/_aiter_ops.py + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + capability = major * 10 + minor + aiter_supported = is_aiter_found_and_supported() and 90 < capability < 100 + + if selected_backend is not None: + backend_upper = selected_backend.upper() + if backend_upper == "FLASH_ATTN" and not aiter_supported: + logger.warning( + "Flash Attention requires `aiter` library which is only supported " + "on gfx942 and gfx950. Falling back to TORCH_SDPA backend." + ) + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + backend = DiffusionAttentionBackendEnum[backend_upper] + logger.info("Using diffusion attention backend '%s'", backend_upper) + return backend.get_path() + + # Choose to enable Flash Attention by default on ROCm + # whenever possible as it is the fastest backend + if aiter_supported: + logger.info("Defaulting to diffusion attention backend FLASH_ATTN") + return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() + + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + + @classmethod + def supports_torch_inductor(cls) -> bool: + return True + + @classmethod + def get_default_stage_config_path(cls) -> str: + return "vllm_omni/platforms/rocm/stage_configs" + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + if local_rank is None: + return torch.device("cuda") + return torch.device("cuda", local_rank) + + @classmethod + def get_device_count(cls) -> int: + return torch.cuda.device_count() + + @classmethod + def get_device_version(cls) -> str | None: + if torch.version.hip is not None: + hip_version = torch.version.hip + return hip_version.split("-")[0] + return None + + @classmethod + def synchronize(cls) -> None: + torch.cuda.synchronize() + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + free, _ = torch.cuda.mem_get_info(device) + return free diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7887cd2bb0b54ccc3eb7675ba6de68b6f76cf0ec --- /dev/null +++ b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,102 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 2x H100-80G GPU. 
+stage_args: + - stage_id: 0 + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + + - stage_id: 1 + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + + - stage_id: 2 + runtime: + process: true + devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31312673ae8c5bb5690160f951523f710aae7ff4 --- /dev/null +++ b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml @@ -0,0 +1,97 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 2x H100-80G GPUs. 
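+# Device layout below: the thinker runs on device 0 while the talker and code2wav
+# stages share device 1, so their gpu_memory_utilization values (0.6 and 0.1) are
+# kept low enough to coexist on one GPU.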
+stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 1 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + # tensor_parallel_size: 2 + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/xpu/__init__.py b/vllm_omni/platforms/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec0986c2e8b6291797f74e9a810da65770e3033 --- /dev/null +++ b/vllm_omni/platforms/xpu/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.xpu.platform import XPUOmniPlatform + +__all__ = ["XPUOmniPlatform"] diff --git a/vllm_omni/platforms/xpu/platform.py b/vllm_omni/platforms/xpu/platform.py new file mode 100644 index 0000000000000000000000000000000000000000..43d63570285e607258d097f433d26881bb8d690a --- /dev/null +++ b/vllm_omni/platforms/xpu/platform.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger +from vllm.platforms.xpu import XPUPlatform + +from vllm_omni.diffusion.attention.backends.registry import 
DiffusionAttentionBackendEnum +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum + +logger = init_logger(__name__) + + +class XPUOmniPlatform(OmniPlatform, XPUPlatform): + """XPU/Intel GPU implementation of OmniPlatform. + + Inherits all XPU-specific implementations from vLLM's XPUPlatform, + and adds Omni-specific interfaces from OmniPlatform. + """ + + _omni_enum = OmniPlatformEnum.XPU + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + return "vllm_omni.platforms.xpu.worker.xpu_ar_worker.XPUARWorker" + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + return "vllm_omni.platforms.xpu.worker.xpu_generation_worker.XPUGenerationWorker" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + if selected_backend is not None: + backend_upper = selected_backend.upper() + backend = DiffusionAttentionBackendEnum[backend_upper] + logger.info("Using diffusion attention backend '%s'", backend_upper) + return backend.get_path() + + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + + @classmethod + def supports_torch_inductor(cls) -> bool: + return True + + @classmethod + def get_default_stage_config_path(cls) -> str: + return "vllm_omni/platforms/xpu/stage_configs" + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + if local_rank is None: + return torch.device("xpu") + return torch.device("xpu", local_rank) + + @classmethod + def get_device_count(cls) -> int: + return torch.xpu.device_count() + + @classmethod + def get_device_version(cls) -> str | None: + # XPU does not have a version string like CUDA + return None + + @classmethod + def synchronize(cls) -> None: + torch.xpu.synchronize() + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + if device is None: + device_id = 0 + else: + device_id = device.index if device.index is not None else 0 + props = torch.xpu.get_device_properties(device_id) + return props.total_memory diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fe1b79e45b132987eb0cdbdf9e449b28e995d55 --- /dev/null +++ b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,101 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. + +# The following config has been verified on 2x 1550-64G XPUs. 
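+# Device layout below: the thinker and code2wav stages share device 0 while the
+# talker runs on device 1; each stage sets `stage_type: llm` so it is launched
+# through OmniLLM.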
+stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true # Run this stage in a separate process + devices: "0" # Visible devices for this stage + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: false + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + stop_token_ids: [8294] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + process: true + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + +# Top-level runtime config (concise): default windows and stage edges +runtime: + enabled: true + defaults: + window_size: -1 # Simplified: trigger downstream only after full upstream completion + max_inflight: 1 # Simplified: process serially within each stage + + edges: + - from: 0 # thinker → talker: trigger only after receiving full input (-1) + to: 1 + window_size: -1 + - from: 1 # talker → code2wav: trigger only after receiving full input (-1) + to: 2 + window_size: -1 diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0803d735342b3cc65313d568da7ef0181dbd2cde --- /dev/null +++ b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml @@ -0,0 +1,99 @@ +# Stage config for running Qwen3-Omni-MoE with 3-stage architecture +# Stage 0: Thinker (multimodal understanding + text generation) +# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) +# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) + +# The following config has been verified on 4x 1550-64G XPUs. 
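+# Device layout below: the thinker spans devices 0-1 with tensor_parallel_size 2,
+# the talker runs on device 2, and code2wav runs on device 3.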
+stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + hf_config_name: thinker_config + tensor_parallel_size: 2 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "2" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + # final_output: true + # final_output_type: text + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM + runtime: + devices: "3" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/xpu/utils.py b/vllm_omni/platforms/xpu/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..18437cae3893ac9f89d6b69c7e61a0f4899d3417 --- /dev/null +++ b/vllm_omni/platforms/xpu/utils.py @@ -0,0 +1,16 @@ +from contextlib import contextmanager + +import torch + + +@contextmanager +def torch_cuda_wrapper(): + try: + # replace cuda APIs with xpu APIs, this should work by default + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.default_stream = torch.xpu.current_stream + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.stream = torch.xpu.stream + yield + finally: + pass diff --git a/vllm_omni/platforms/xpu/worker/__init__.py b/vllm_omni/platforms/xpu/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/vllm_omni/platforms/xpu/worker/xpu_ar_model_runner.py b/vllm_omni/platforms/xpu/worker/xpu_ar_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..2235d354877061671ec704fb8c12d85c14f6849b --- /dev/null +++ b/vllm_omni/platforms/xpu/worker/xpu_ar_model_runner.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm_omni.platforms.xpu.utils import torch_cuda_wrapper +from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner + + +class XPUARModelRunner(GPUARModelRunner): + def __init__(self, *args, **kwargs): + with torch_cuda_wrapper(): + super().__init__(*args, **kwargs) + + def _init_device_properties(self): + self.num_sms = None + + def _sync_device(self) -> None: + torch.xpu.synchronize() diff --git a/vllm_omni/platforms/xpu/worker/xpu_ar_worker.py b/vllm_omni/platforms/xpu/worker/xpu_ar_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6349264596762ecf85410921c5f822a2854bbb --- /dev/null +++ b/vllm_omni/platforms/xpu/worker/xpu_ar_worker.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.worker.xpu_worker import XPUWorker + +from vllm_omni.platforms.xpu.worker.xpu_ar_model_runner import XPUARModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + + +class XPUARWorker(OmniWorkerMixin, XPUWorker): + """XPU AR worker for thinker/talker stages in Omni model.""" + + def init_device(self): + super().init_device() + self.model_runner: XPUARModelRunner = XPUARModelRunner(self.vllm_config, self.device) diff --git a/vllm_omni/platforms/xpu/worker/xpu_generation_model_runner.py b/vllm_omni/platforms/xpu/worker/xpu_generation_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..e865ac5e0951cac9d92bbdba0a4a7547f5a15162 --- /dev/null +++ b/vllm_omni/platforms/xpu/worker/xpu_generation_model_runner.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm_omni.platforms.xpu.utils import torch_cuda_wrapper +from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner + + +class XPUGenerationModelRunner(GPUGenerationModelRunner): + def __init__(self, *args, **kwargs): + with torch_cuda_wrapper(): + super().__init__(*args, **kwargs) + + def _init_device_properties(self): + self.num_sms = None + + def _sync_device(self) -> None: + torch.xpu.synchronize() diff --git a/vllm_omni/platforms/xpu/worker/xpu_generation_worker.py b/vllm_omni/platforms/xpu/worker/xpu_generation_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6d530eb0c4ee9632048e4ccbd814511f729d40 --- /dev/null +++ b/vllm_omni/platforms/xpu/worker/xpu_generation_worker.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.worker.xpu_worker import XPUWorker + +from vllm_omni.platforms.xpu.worker.xpu_generation_model_runner import XPUGenerationModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + + +class XPUGenerationWorker(OmniWorkerMixin, XPUWorker): + """XPU generation worker for the code2wav (non-AR waveform generation) stage in the Omni model.""" + + def init_device(self): + super().init_device() + self.model_runner: XPUGenerationModelRunner = XPUGenerationModelRunner(self.vllm_config, 
self.device) diff --git a/vllm_omni/plugins/__init__.py b/vllm_omni/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41bf398cabf5a100b0984820ac10bcd3e419327d --- /dev/null +++ b/vllm_omni/plugins/__init__.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +from collections.abc import Callable +from typing import Any + +import vllm.envs as envs + +logger = logging.getLogger(__name__) + +# Default plugins group will be loaded in all processes (process0, engine core +# process and worker processes). +OMNI_DEFAULT_PLUGINS_GROUP = "vllm_omni.general_plugins" +# Platform plugins group will be loaded in all processes when +# `vllm_omni.platforms.current_omni_platform` is called and the value not +# initialized. +OMNI_PLATFORM_PLUGINS_GROUP = "vllm_omni.platform_plugins" + +# Make sure one process only loads plugins once. +omni_plugins_loaded = False + + +def load_omni_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group=group) + if len(discovered_plugins) == 0: + logger.debug("No plugins for group %s found.", group) + return {} + + # Check if the only discovered plugin is the default one. + is_default_group = group == OMNI_DEFAULT_PLUGINS_GROUP + # Use INFO for non-default groups and DEBUG for the default group. + log_level = logger.debug if is_default_group else logger.info + + log_level("Available plugins for group %s:", group) + for plugin in discovered_plugins: + log_level("- %s -> %s", plugin.name, plugin.value) + + if allowed_plugins is None: + log_level("All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.") + + plugins: dict[str, Callable[[], Any]] = {} + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + if allowed_plugins is not None: + log_level("Loading plugin %s", plugin.name) + + try: + func = plugin.load() + plugins[plugin.name] = func + except Exception: + logger.exception("Failed to load plugin %s", plugin.name) + + return plugins + + +def load_omni_general_plugins() -> None: + """WARNING: plugins can be loaded for multiple times in different + processes. They should be designed in a way that they can be loaded + multiple times without causing issues. + """ + global omni_plugins_loaded + if omni_plugins_loaded: + return + omni_plugins_loaded = True + + plugins = load_omni_plugins_by_group(group=OMNI_DEFAULT_PLUGINS_GROUP) + # General plugins: we only need to execute the loaded functions. + for func in plugins.values(): + func() + + +__all__ = [ + "load_omni_general_plugins", + "OMNI_DEFAULT_PLUGINS_GROUP", + "OMNI_PLATFORM_PLUGINS_GROUP", +] diff --git a/vllm_omni/request.py b/vllm_omni/request.py new file mode 100644 index 0000000000000000000000000000000000000000..c992816de6e191ec017584213b0f9160fe8f5793 --- /dev/null +++ b/vllm_omni/request.py @@ -0,0 +1,94 @@ +from collections.abc import Callable +from typing import TYPE_CHECKING + +import numpy as np +import torch +from vllm.v1.request import Request + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import BlockHash + +from vllm_omni.engine import AdditionalInformationPayload, OmniEngineCoreRequest, PromptEmbedsPayload + + +class OmniRequest(Request): + """Request class for omni models, extending the base Request. 
+ + This class extends the base vLLM Request with support for prompt + embeddings and additional information payloads, enabling direct + transfer of pre-computed embeddings between stages. + + Args: + prompt_embeds: Optional serialized prompt embeddings payload. + Used for direct transfer of embeddings between stages. + additional_information: Optional additional information payload + containing tensors or lists to be passed along with the request. + """ + + def __init__( + self, + prompt_embeds: PromptEmbedsPayload | torch.Tensor | None = None, + # Optional external request ID for tracking + external_req_id: str | None = None, + additional_information: AdditionalInformationPayload | None = None, + *args, + **kwargs, + ): + prompt_embeds_tensor = self._maybe_decode_prompt_embeds(prompt_embeds) + super().__init__(prompt_embeds=prompt_embeds_tensor, *args, **kwargs) + # Preserve serialized prompt embeddings payload (optional) + self.prompt_embeds_payload: PromptEmbedsPayload | None = ( + prompt_embeds if isinstance(prompt_embeds, PromptEmbedsPayload) else None + ) + # Optional external request ID for tracking + self.external_req_id: str | None = external_req_id + # Serialized additional information payload (optional) + self.additional_information: AdditionalInformationPayload | None = additional_information + + @staticmethod + def _maybe_decode_prompt_embeds( + prompt_embeds: PromptEmbedsPayload | torch.Tensor | None, + ) -> torch.Tensor | None: + if isinstance(prompt_embeds, PromptEmbedsPayload): + dtype = getattr(np, prompt_embeds.dtype) + arr = np.frombuffer(prompt_embeds.data, dtype=dtype) + arr = arr.reshape(prompt_embeds.shape) + return torch.from_numpy(arr) + return prompt_embeds + + @classmethod + def from_engine_core_request( + cls, + request: OmniEngineCoreRequest, + block_hasher: Callable[["Request"], list["BlockHash"]] | None, + ) -> "Request": + """Create an OmniRequest from an OmniEngineCoreRequest. 
+ + Args: + request: The OmniEngineCoreRequest to convert + block_hasher: Optional function to compute block hashes for + prefix caching + + Returns: + OmniRequest instance created from the engine core request + """ + return cls( + request_id=request.request_id, + # Optional external request ID for tracking + external_req_id=request.external_req_id, + client_index=request.client_index, + prompt_token_ids=request.prompt_token_ids, + prompt_embeds=request.prompt_embeds, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + cache_salt=request.cache_salt, + priority=request.priority, + trace_headers=request.trace_headers, + block_hasher=block_hasher, + additional_information=request.additional_information, + resumable=request.resumable, + ) diff --git a/vllm_omni/sample/__init__.py b/vllm_omni/sample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/utils/__init__.py b/vllm_omni/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/version.py b/vllm_omni/version.py new file mode 100644 index 0000000000000000000000000000000000000000..15253fe6fd6aacfee8c795b35e1dcd41b2d31a5c --- /dev/null +++ b/vllm_omni/version.py @@ -0,0 +1,3 @@ +__version__ = "0.14.0" +__version_tuple__ = (0, 14, 0) +# TODO: add auto version generation diff --git a/vllm_omni/worker/__init__.py b/vllm_omni/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..9332abeb1a9f26f8077738191ecb2c295d16cad5 --- /dev/null +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -0,0 +1,619 @@ +"""AR GPU Model Runner for vLLM-Omni. + +Exposes per-request hidden representations via ModelRunnerOutput.pooler_output +and also outputs sampled tokens. 
+""" + +from __future__ import annotations + +from copy import copy +from typing import Any, NamedTuple + +import numpy as np +import torch +from vllm.config import CUDAGraphMode +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_model_runner import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncGPUModelRunnerOutput, + IntermediateTensors, +) +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm.v1.worker.utils import is_residual_scattered_for_sp + +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + +logger = init_logger(__name__) + + +class ExecuteModelState(NamedTuple): + scheduler_output: SchedulerOutput + logits: torch.Tensor | None + spec_decode_metadata: Any + spec_decode_common_attn_metadata: Any + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + ec_connector_output: Any + cudagraph_stats: Any + multimodal_outputs: Any + + +class GPUARModelRunner(OmniGPUModelRunner): + """Autoregressive GPU model runner that returns hidden states per request. + + Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and + reuses Omni hooks for additional_information / multimodal outputs. This + class only overrides sample_tokens to expose hidden states + multimodal + outputs per request while keeping Async output semantics. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + # each model stage has their own hidden size + self.hidden_size = self.model_config.hf_text_config.hidden_size + self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) + # Initialize KV cache manager (preserve vllm_config fallback behavior) + self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) + + def _make_buffer(self, *size, dtype, numpy=True): + # Prevent ray from pinning the buffer due to large size + from vllm_omni.distributed.ray_utils.utils import ( + calculate_total_bytes, + maybe_disable_pin_memory_for_ray, + ) + + total_bytes = calculate_total_bytes(size, dtype) + + # Use the context manager to temporarily disable pinning if needed + with maybe_disable_pin_memory_for_ray(self, total_bytes): + return super()._make_buffer(*size, dtype=dtype, numpy=numpy) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") + + # [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) + self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( + finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + kv_caches=self.kv_caches, + block_size=self.cache_config.block_size, + cache_dtype=str(self.cache_config.cache_dtype), + request_id_resolver=self._resolve_global_request_id, + ) + + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + # Update persistent batch states. + self._update_states(scheduler_output) + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, + ) + + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) + + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + # True if any attention backend handles KV cache update separately + # from forward() (i.e., forward_includes_kv_cache_update=False). When true, + # slot_mappings must use padded dimensions to match the key/value tensors. 
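+        # Encoder-only attention groups are excluded from the check below because
+        # they do not take part in the decoder KV-cache update.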
+ from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec + + has_separate_kv_update = not all( + all(g.backend.forward_includes_kv_cache_update for g in self.attn_groups[id]) + for id, spec in enumerate(self.kv_cache_config.kv_cache_groups) + if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec) + ) + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded if pad_attn or has_separate_kv_update else num_tokens_unpadded, + num_reqs_padded=(num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs), + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, + ) + + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False + + # Run the model. + # Use persistent buffers for CUDA graphs. + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, # OMNI: required for KV cache operations + ), + record_function_or_nullcontext("gpu_model_runner: forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + ) + + with record_function_or_nullcontext("gpu_model_runner: postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None + + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output) + if multimodal_outputs is not None: + keys_or_type = ( + list(multimodal_outputs.keys()) + if isinstance(multimodal_outputs, dict) + else type(multimodal_outputs) + ) + logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}") + else: + logger.debug("[AR] execute_model: multimodal_outputs is None") + + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. 
+ assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + return self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output, + ) + + sample_hidden_states = hidden_states[logits_indices] + # Try with sampling_metadata first; fall back to without for models that don't support it + try: + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) + except TypeError: + logits = self.model.compute_logits(sample_hidden_states) + else: + # Rare case. + assert not self.is_pooling_model + + sample_hidden_states = hidden_states[logits_indices] + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp(self.vllm_config, num_tokens_padded) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + # Try with sampling_metadata first; fall back to without for models that don't support it + try: + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) + except TypeError: + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data: dict[str, Any] = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + broadcasted = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert broadcasted is not None + logits = broadcasted["logits"] + + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) + self.kv_connector_output = kv_connector_output + + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None) + self.kv_extracted_req_ids = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # type: ignore[return-value] + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) = self.execute_model_state + self.execute_model_state = None + + # Apply structured output bitmasks if present. 
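+        # The grammar bitmask masks out logits of tokens the structured-output grammar
+        # disallows, so the sampling step below can only emit grammar-valid tokens.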
+ if grammar_output is not None: + apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits) + + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + self._draft_token_ids = None + self._draft_token_req_ids = None + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids(sampled_token_ids): + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("gpu_model_runner: draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + self._copy_draft_token_ids_to_cpu(scheduler_output) + + spec_config = self.speculative_config + propose_drafts_after_bookkeeping = False + if spec_config is not None: + input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + <= self.effective_drafter_max_model_len + ) + if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + assert isinstance(self.drafter, EagleProposer) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) + # Since we couldn't run the drafter, + # just use zeros for the draft tokens. + self._draft_token_ids = torch.zeros(1, device=self.device, dtype=torch.int32).expand( + len(self.input_batch.req_ids), self.num_spec_tokens + ) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + else: + propose_drafts_after_bookkeeping = input_fits_in_drafter + + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + + if propose_drafts_after_bookkeeping: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
+ propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("gpu_model_runner: eplb"): + self.eplb_step() + + hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() + num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) + if num_scheduled_tokens_np is None: + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + + self._process_additional_information_updates( + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + ) + + pooler_output: list[dict[str, object]] = [] + for rid in req_ids_output_copy: + idx = req_id_to_index_output_copy[rid] + start = int(self.query_start_loc.cpu[idx]) + sched = int(num_scheduled_tokens_np[idx]) + end = start + sched + hidden_slice = hidden_states_cpu[start:end] + payload: dict[str, object] = {"hidden": hidden_slice} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + mm_payload: dict[str, object] = {} + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_payload[k] = v.detach().to("cpu")[start:end].contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: + sub_dict[str(sk)] = sv.detach().to("cpu")[start:end].contiguous() + if sub_dict: + mm_payload[k] = sub_dict + elif isinstance(v, list): + element = v[0] + if isinstance(element, torch.Tensor): + element = element.detach().to("cpu").contiguous() + mm_payload[k] = element + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + if mm_payload: + payload.update(mm_payload) + pooler_output.append(payload) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + if self.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.slot_mapping) # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + kv_connector_output=kv_connector_output, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, + num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, + ) + output.kv_extracted_req_ids = kv_extracted_req_ids + + if not self.use_async_scheduling: + return output + with record_function_or_nullcontext("gpu_model_runner: AsyncGPUModelRunnerOutput"): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + with record_function_or_nullcontext("gpu_model_runner: set_async_sampled_token_ids"): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. 
+ self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + + def _resolve_global_request_id(self, req_id: str) -> str: + """Resolve global request ID from request state.""" + req_state = self.requests.get(req_id) + if not req_state: + return req_id + + add_info = getattr(req_state, "additional_information_cpu", {}) or {} + global_id = add_info.get("global_request_id") + if global_id: + if isinstance(global_id, list) and global_id: + global_id = global_id[0] + if isinstance(global_id, bytes): + return global_id.decode("utf-8") + return str(global_id) + return req_id diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..98b5afddab3943bfbfa4362cd97acce53770bd09 --- /dev/null +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -0,0 +1,104 @@ +import gc +import os + +import torch +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.mem_utils import MemorySnapshot, format_gib +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.utils import report_usage_stats +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm.v1.worker.gpu_worker import init_worker_distributed_environment +from vllm.v1.worker.utils import request_memory +from vllm.v1.worker.workspace import init_workspace_manager + +from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + +logger = init_logger(__name__) + + +class GPUARWorker(OmniWorkerMixin, GPUWorker): + """GPU worker for autoregressive omni model stages. + + Extends the base GPUWorker to initialize and manage autoregressive + model runners for text generation stages (e.g., thinker stages). + """ + + def init_device(self): + if self.device_config.device_type == "cuda": + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + parallel_config = self.parallel_config + if ( + parallel_config.distributed_executor_backend not in ("ray", "external_launcher") + and parallel_config.data_parallel_backend != "ray" + and parallel_config.nnodes_within_dp == 1 + ): + # Use local DP rank if available, otherwise use global DP rank. + dp_local_rank = self.parallel_config.data_parallel_rank_local + if dp_local_rank is None: + dp_local_rank = self.parallel_config.data_parallel_index + + tp_pp_world_size = ( + self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size + ) + + # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK + self.local_rank += dp_local_rank * tp_pp_world_size + assert self.local_rank < torch.cuda.device_count(), ( + f"DP adjusted local rank {self.local_rank} is out of bounds. " + ) + visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must " + f"be less than or equal to the number of visible devices " + f"({visible_device_count})." 
+            )
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            current_platform.set_device(self.device)
+
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
+
+            # Initialize the distributed environment BEFORE taking
+            # memory snapshot
+            # This ensures NCCL buffers are allocated before we measure
+            # available memory
+            init_worker_distributed_environment(
+                self.vllm_config,
+                self.rank,
+                self.distributed_init_method,
+                self.local_rank,
+                current_platform.dist_backend,
+            )
+
+            # Set random seed.
+            set_random_seed(self.model_config.seed)
+
+            # Now take memory snapshot after NCCL is initialized
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            # take current memory snapshot
+            self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
+            self.requested_memory = request_memory(init_snapshot, self.cache_config)
+            logger.debug("worker init memory snapshot: %r", self.init_snapshot)
+            logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
+        else:
+            raise RuntimeError(f"Unsupported device type: {self.device_config.device}")
+
+        # Initialize workspace manager
+        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
+        init_workspace_manager(self.device, num_ubatches)
+
+        if self.use_v2_model_runner:
+            # OMNI: v2 model runner does not yet include omni hooks.
+            logger.warning("OMNI GPUARWorker forces v1 model runner for omni hooks.")
+            self.use_v2_model_runner = False
+
+        # Construct the model runner
+        self.model_runner = GPUARModelRunner(self.vllm_config, self.device)
+
+        if self.rank == 0:
+            # If usage stat is enabled, collect relevant info.
+            report_usage_stats(self.vllm_config)
diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..9125cc52a7c48c067c68af484074b1483f49a386
--- /dev/null
+++ b/vllm_omni/worker/gpu_generation_model_runner.py
@@ -0,0 +1,789 @@
+"""Code2Wav GPU Model Runner for vLLM-Omni.
+
+Handles direct conversion from codec codes to audio waveforms for Qwen3 Omni MoE Code2Wav.
+This is a non-autoregressive model that doesn't require sampling or logits computation.
+""" + +from __future__ import annotations + +import gc +import logging +from copy import copy + +import numpy as np +import torch +from vllm.config import CUDAGraphMode +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) +from vllm.utils.math_utils import cdiv +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_model_runner import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncGPUModelRunnerOutput, + IntermediateTensors, + PerLayerAttnMetadata, +) +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs + +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + +logger = logging.getLogger(__name__) + + +class GPUGenerationModelRunner(OmniGPUModelRunner): + """Generation model runner for vLLM-Omni (non-autoregressive). + + - Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue. + - Does not compute logits or perform token sampling. + - Executes generation process and returns tensors via `pooler_output`. + """ + + def _update_request_states(self, scheduler_output: SchedulerOutput): + cached_reqs = scheduler_output.scheduled_cached_reqs + for _, req_id in enumerate(cached_reqs.req_ids): + req_state = self.requests.get(req_id) + assert req_state is not None + req_state.prompt_token_ids = cached_reqs.prompt_token_ids.get(req_id) + self.input_batch.remove_request(req_id) + # update the request state in self.input_batch + self.input_batch.add_request(req_state) + self._init_mrope_positions(req_state) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: IntermediateTensors | None = None, + ) -> OmniModelRunnerOutput | IntermediateTensors: + if self.execute_model_state is not None: + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") + + if self.vllm_config.model_config.enable_return_routed_experts: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + if self.model_config.async_chunk: + self._update_request_states(scheduler_output) + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + return EMPTY_MODEL_RUNNER_OUTPUT + + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + 
self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, + ) + + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) + + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + # OMNI: True if any attention backend handles KV cache update separately + # from forward() (i.e., forward_includes_kv_cache_update=False). When true, + # slot_mappings must use padded dimensions to match the key/value tensors. 
+ from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec + + has_separate_kv_update = not all( + all(g.backend.forward_includes_kv_cache_update for g in self.attn_groups[id]) + for id, spec in enumerate(self.kv_cache_config.kv_cache_groups) + if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec) + ) + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded if pad_attn or has_separate_kv_update else num_tokens_unpadded, + num_reqs_padded=(num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs), + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, + ) + + ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) = self._preprocess( + scheduler_output, + num_tokens_padded, + intermediate_tensors, + ) + + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False + + # Run the model. + # Use persistent buffers for CUDA graphs. + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, # OMNI: required for KV cache operations + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + outputs = self._run_generation_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + model_kwargs=model_kwargs, + logits_indices=logits_indices, + ) + + _, multimodal_outputs = self.extract_multimodal_outputs(outputs) + self.execute_model_state = ExecuteModelState( + scheduler_output, + None, + spec_decode_metadata, + spec_decode_common_attn_metadata, + None, + None, + None, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None = None, + ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + # NOTE: Even though the model is non-autoregressive, we still need + # this function to match the interface of the engine core. + # In this case, this function + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. 
+ if not kv_connector_output: + return None # type: ignore[return-value] + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + multimodal_outputs, + ) = self.execute_model_state + self.execute_model_state = None + + pooler_output: list[object] = [] + if isinstance(multimodal_outputs, torch.Tensor): + assert multimodal_outputs.shape[0] == 1, ( + "model should return a single tensor, to return multiple tensors, use a dict" + ) + assert multimodal_outputs.shape[0] == self.input_batch.num_reqs + for i in range(self.input_batch.num_reqs): + pooler_output.append({"model_outputs": multimodal_outputs[i].detach().to("cpu").contiguous()}) + elif isinstance(multimodal_outputs, list): + assert len(multimodal_outputs) == 1, ( + "model should return a single list, to return multiple lists, use a dict" + ) + for out in multimodal_outputs: + pooler_output.append( + {"model_outputs": out.detach().to("cpu").contiguous() if out is not None else None} + ) + elif isinstance(multimodal_outputs, dict): + mm_payload = {} + for key, out in multimodal_outputs.items(): + if out is not None and isinstance(out, torch.Tensor): + mm_payload[key] = out.detach().to("cpu").contiguous() + pooler_output.append(mm_payload) + else: + raise RuntimeError("Unsupported diffusion output type") + output = OmniModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + num_nans_in_logits={}, + cudagraph_stats=cudagraph_stats, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, + ) + + if not self.use_async_scheduling: + return output + + return AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=torch.tensor([], device=self.device), + invalid_req_indices=[], + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + logprobs_tensors=None, + ) + + def _run_generation_model( + self, + *, + input_ids: torch.Tensor | None, + positions: torch.Tensor | None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None, + model_kwargs: dict, + logits_indices: torch.Tensor, + ) -> torch.Tensor | list[torch.Tensor]: + """Run generation from codec codes to waveforms. + + Args: + scheduler_output: Contains codec codes in input_ids or additional info + intermediate_tensors: PP intermediate tensors if applicable + + Returns: + Audio waveforms: [batch, 1, waveform_len] or list of tensors + """ + # Keep inputs identical to AR runner + kwargs = dict( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + ) + + if hasattr(self.model, "forward"): + return self._model_forward(**kwargs) + + raise RuntimeError( + "The loaded model does not expose diffusion interfaces 'sample', " + "'forward', or 'diffuse'. 
Please implement one of them or adapt the runner." + ) + + @torch.inference_mode() + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + logger.warning("Dummy sampler run is not implemented for generation model") + return None + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - if not set will determine the cudagraph mode based on using + the self.cudagraph_dispatcher. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. + remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. + """ + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. + return torch.tensor([]), torch.tensor([]) + + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
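+        # Illustrative (hypothetical) splits produced by the branches below:
+        #   uniform_decode, num_tokens=9,  max_query_len=2, max_num_seqs=8 -> [2, 2, 2, 2, 1]
+        #   default,        num_tokens=10, max_num_seqs=4                  -> [2, 2, 2, 4]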
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + assert not create_mixed_batch + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) + + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, + ) + ) + + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cudagraph_mode + else: + assert cudagraph_runtime_mode == _cudagraph_mode, ( + f"Cudagraph runtime mode mismatch in dummy_run. " + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + attn_metadata: PerLayerAttnMetadata | None = None + + # OMNI: Get slot mappings before building attention metadata + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + + # If force_attention is True, we always capture attention. 
Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + seq_lens = max_query_len # type: ignore[assignment] + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, + ) + + with self.maybe_dummy_run_with_lora( + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, + ): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_padded <= self.max_num_tokens + model_kwargs = self._init_model_kwargs() + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + + model_kwargs = { + **model_kwargs, + **self._dummy_mm_kwargs(num_reqs), + } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + model_kwargs = self._init_model_kwargs() + else: + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] + else: + positions = self.positions.gpu[:num_tokens_padded] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) + + if ubatch_slices_padded is not None: + # Adjust values to reflect a single ubatch. + # TODO(sage,lucas): this is cruft that should be addressed in + # the padding refactor. 
+ num_tokens_padded = ubatch_slices_padded[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_padded + + with ( + self.maybe_randomize_inputs(input_ids, inputs_embeds), + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, # OMNI: required for KV cache operations + ), + ): + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. + use_cudagraphs = ( + (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE) + or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE) + ) and not self.speculative_config.enforce_eager + + # Note(gnovack) - We need to disable cudagraphs for one of the two + # lora cases when cudagraph_specialize_lora is enabled. This is a + # short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/28334 + if self.compilation_config.cudagraph_specialize_lora and activate_lora: + use_cudagraphs = False + + self.drafter.dummy_run( + num_tokens, + use_cudagraphs=use_cudagraphs, + is_graph_capturing=is_graph_capturing, + ) + + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. + # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + return hidden_states, None + + def profile_run(self) -> None: + # Profile with multimodal encoder & encoder cache. 
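+        # High-level flow (descriptive note): when multimodal inputs are supported and
+        # profiling is not skipped, a dummy batch of the most expensive modality is run
+        # through the encoder to exercise the encoder cache, then a full dummy forward
+        # pass with is_profile=True is executed before the temporary caches are cleared.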
+ if self.supports_mm_inputs: + mm_config = self.model_config.multimodal_config + if mm_config is not None and mm_config.skip_mm_profiling: + logger.info("Skipping memory profiling for multimodal encoder and encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None + + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[dummy_modality] + + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) + + # Run multimodal encoder. + dummy_encoder_outputs = self.model.embed_multimodal(**batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # NOTE: This happens when encoder cache needs to store + # the embeddings that encoder outputs are scattered onto. + # In this case we create dummy embeddings of size + # (max_tokens_for_modality, hidden_size) and scatter + # encoder output into it. + encoder_output_shape = dummy_encoder_outputs[0].shape + max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[dummy_modality] + if encoder_output_shape[0] < max_mm_tokens_per_item: + encoder_hidden_size = encoder_output_shape[-1] + expanded_outputs = [] + for output in dummy_encoder_outputs: + expanded = output.new_zeros((max_mm_tokens_per_item, encoder_hidden_size)) + num_tokens = output.shape[0] + expanded[:num_tokens].copy_(output) + expanded_outputs.append(expanded) + + dummy_encoder_outputs = expanded_outputs + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Add `is_profile` here to pre-allocate communication buffers + hidden_states, _ = self._dummy_run(self.max_num_tokens, is_profile=True) + output = None + self._sync_device() + del hidden_states + self.encoder_cache.clear() + gc.collect() diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..91f90d6b41e6cc73a16ee6bc87ab0262ecd2da2c --- /dev/null +++ b/vllm_omni/worker/gpu_generation_worker.py @@ -0,0 +1,103 @@ +import gc +import os + +import torch +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.mem_utils import MemorySnapshot, format_gib +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.utils import report_usage_stats +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm.v1.worker.gpu_worker import init_worker_distributed_environment +from vllm.v1.worker.utils import request_memory +from vllm.v1.worker.workspace import init_workspace_manager + +from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner +from vllm_omni.worker.mixins import OmniWorkerMixin + +logger = init_logger(__name__) + + +class GPUGenerationWorker(OmniWorkerMixin, GPUWorker): + """GPU Worker for Generation model (non-autoregressive waveform generation). 
+
+    Usage in stage config:
+        worker_cls: "vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker"
+    """
+
+    def init_device(self):
+        if self.device_config.device_type == "cuda":
+            # This env var set by Ray causes exceptions with graph building.
+            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            parallel_config = self.parallel_config
+            if (
+                parallel_config.distributed_executor_backend not in ("ray", "external_launcher")
+                and parallel_config.data_parallel_backend != "ray"
+                and parallel_config.nnodes_within_dp == 1
+            ):
+                # Use local DP rank if available, otherwise use global DP rank.
+                dp_local_rank = self.parallel_config.data_parallel_rank_local
+                if dp_local_rank is None:
+                    dp_local_rank = self.parallel_config.data_parallel_index
+
+                tp_pp_world_size = (
+                    self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size
+                )
+
+                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+                self.local_rank += dp_local_rank * tp_pp_world_size
+                assert self.local_rank < torch.cuda.device_count(), (
+                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
+                )
+            visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+            assert self.parallel_config.local_world_size <= visible_device_count, (
+                f"local_world_size ({self.parallel_config.local_world_size}) must "
+                f"be less than or equal to the number of visible devices "
+                f"({visible_device_count})."
+            )
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            current_platform.set_device(self.device)
+
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
+
+            # Initialize the distributed environment BEFORE taking
+            # memory snapshot
+            # This ensures NCCL buffers are allocated before we measure
+            # available memory
+            init_worker_distributed_environment(
+                self.vllm_config,
+                self.rank,
+                self.distributed_init_method,
+                self.local_rank,
+                current_platform.dist_backend,
+            )
+
+            # Set random seed.
+            set_random_seed(self.model_config.seed)
+
+            # Now take memory snapshot after NCCL is initialized
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            # take current memory snapshot
+            self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
+            self.requested_memory = request_memory(init_snapshot, self.cache_config)
+            logger.debug("worker init memory snapshot: %r", self.init_snapshot)
+            logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
+        else:
+            raise RuntimeError(f"Unsupported device type: {self.device_config.device}")
+
+        # Initialize workspace manager
+        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
+        init_workspace_manager(self.device, num_ubatches)
+
+        if self.use_v2_model_runner:
+            # OMNI: v2 model runner does not yet include omni hooks.
+            logger.warning("OMNI GPUGenerationWorker forces v1 model runner for omni hooks.")
+            self.use_v2_model_runner = False
+
+        self.model_runner = GPUGenerationModelRunner(self.vllm_config, self.device)
+
+        if self.rank == 0:
+            # If usage stat is enabled, collect relevant info.
+ report_usage_stats(self.vllm_config) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..cb14ed49a3e8670f476190ae873b62c03c1fc2fb --- /dev/null +++ b/vllm_omni/worker/gpu_model_runner.py @@ -0,0 +1,1156 @@ +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +import torch +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.config import CUDAGraphMode +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.interfaces import supports_mrope +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.sampling_params import SamplingType +from vllm.utils.import_utils import LazyLoader +from vllm.utils.math_utils import cdiv +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + xgr_torch_compile = LazyLoader( + "xgr_torch_compile", + globals(), + "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile", + ) + +logger = init_logger(__name__) + + +class OmniGPUModelRunner(GPUModelRunner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._omni_per_req_additional_information: dict[str, dict] | None = None + self._omni_num_scheduled_tokens_np: np.ndarray | None = None + self._omni_last_model_output: object | None = None + + def load_model(self, *args, **kwargs) -> None: + super().load_model(*args, **kwargs) + # TODO move this model specific logic to a separate class + if hasattr(self.model, "talker_mtp") and self.model.talker is not None: + self.talker_mtp = self.model.talker_mtp + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + if cudagraph_mode.has_full_cudagraphs(): + self.talker_mtp = CUDAGraphWrapper( + self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size + max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) + self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) + self.talker_mtp_inputs_embeds = self._make_buffer( + max_batch_size, hidden_size, dtype=self.dtype, numpy=False + ) + self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + + def _init_mrope_positions(self, req_state: CachedRequestState): + """Initialize M-RoPE positions for multimodal inputs. + + Extracts multimodal feature metadata (image grids, video grids, + audio features) and computes M-RoPE positions for proper positional + encoding of multimodal tokens. 
+ + Args: + req_state: Cached request state containing multimodal features + + Raises: + AssertionError: If the model does not support M-RoPE + """ + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + # Check for use_audio_in_video + use_audio_in_video_value = mm_input.get("use_audio_in_video") + if use_audio_in_video_value is not None: + use_audio_in_video = bool(use_audio_in_video_value.item()) + + if supports_mrope(self.get_model()): + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + mm_features=req_state.mm_features, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. 
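+        # Worked example (hypothetical ids) for the set expression computed below:
+        #   cached = {A, B, C}, scheduled = {A, B}, resumed = {B}
+        #   -> unscheduled = cached - (scheduled - resumed) = {B, C}
+        # i.e. in the forced-preemption case a resumed request (B) is also cleared here
+        # and then re-added through the normal resumed-request path.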
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. + unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if sampling_params and sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + self.requests[req_id] = req_state + + # If prompt embeddings are provided, decode and attach to inter_data + try: + if getattr(new_req_data, "prompt_embeds", None) is not None: + payload = new_req_data.prompt_embeds + dtype = getattr(np, payload.dtype) + arr = np.frombuffer(payload.data, dtype=dtype) + arr = arr.reshape(payload.shape) + pe_cpu = torch.from_numpy(arr) + # Store temporarily on CPU; later moved to device in builder + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # Also replace payload with Tensor for user visibility in + # scheduler_output + try: + new_req_data.prompt_embeds = pe_cpu # type: ignore[assignment] + except Exception: + pass + except Exception as e: + logger.error(f"Error decoding prompt embeds: {e}") + # Decode additional_information payloads (dictionary) + try: + if getattr(new_req_data, "additional_information", None) is not None: + payload_info = new_req_data.additional_information + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + else: + from vllm_omni.engine import AdditionalInformationPayload + + if isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", 
"float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr.copy()) + else: + info_dict[k] = entry.list_data + if info_dict: + setattr( + self.requests[req_id], + "additional_information_cpu", + info_dict, + ) + except Exception as e: + logger.error(f"Error decoding additional information: {e}") + pass + + if sampling_params and sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = ( + self.input_batch.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(req_state) + + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._init_xdrope_positions(req_state) + + reqs_to_add.append(self.requests[req_id]) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. + valid_sampled_token_count = self._get_valid_sampled_token_count() + + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_id in req_data.resumed_req_ids + num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + + if req_state.prev_num_draft_len and self.use_async_scheduling: + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step does't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state.num_tokens + if num_new_tokens == 1: + # Avoid slicing list in most common case. 
+ req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens + self.input_batch.num_tokens_no_spec[req_index] = end_idx + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) + else: + assert req_index is None + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + + reqs_to_add.append(req_state) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens + if new_block_ids is not None: + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. 
+ self.input_batch.refresh_metadata() + + @torch.inference_mode() + def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput) -> dict: + if ( + hasattr(self.model, "have_multimodal_outputs") + and self.model.have_multimodal_outputs + and isinstance(hidden_states, OmniOutput) + ): + text_hidden_states = hidden_states.text_hidden_states + multimodal_outputs = hidden_states.multimodal_outputs + + elif isinstance(hidden_states, torch.Tensor): + text_hidden_states = hidden_states + multimodal_outputs = {} + elif isinstance(hidden_states, list) or isinstance(hidden_states, tuple): + text_hidden_states = hidden_states[0] + multimodal_outputs = {} + else: + raise ValueError(f"Invalid hidden states type: {type(hidden_states)}") + return text_hidden_states, multimodal_outputs + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - if not set will determine the cudagraph mode based on using + the self.cudagraph_dispatcher. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. + remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. + """ + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. + return torch.tensor([]), torch.tensor([]) + + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.separate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. 
+ max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + assert not create_mixed_batch + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) + + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, + ) + ) + + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cudagraph_mode + else: + assert cudagraph_runtime_mode == _cudagraph_mode, ( + f"Cudagraph runtime mode mismatch in dummy_run. " + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." 
+ ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + attn_metadata: PerLayerAttnMetadata | None = None + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + seq_lens = max_query_len # type: ignore[assignment] + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, + ) + + with self.maybe_dummy_run_with_lora( + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, + ): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_padded <= self.max_num_tokens + model_kwargs = self._init_model_kwargs() + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + + model_kwargs = { + **model_kwargs, + **self._dummy_mm_kwargs(num_reqs), + } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + model_kwargs = self._init_model_kwargs() + else: + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] + else: + positions = self.positions.gpu[:num_tokens_padded] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) + + if ubatch_slices_padded is not None: + # Adjust values to reflect a single ubatch. + # TODO(sage,lucas): this is cruft that should be addressed in + # the padding refactor. 
+ num_tokens_padded = ubatch_slices_padded[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_padded + + with ( + self.maybe_randomize_inputs(input_ids, inputs_embeds), + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, + ), + ): + if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + num_tokens_padded_talker_mtp = num_tokens_padded + if num_tokens_padded_talker_mtp == self.max_num_tokens: + num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0] + outputs = self.talker_mtp( + self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp], + self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp], + self.text_step.gpu[:num_tokens_padded_talker_mtp], + ) + self.compilation_config.cache_dir = None + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. + use_cudagraphs = ( + (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE) + or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE) + ) and not self.speculative_config.enforce_eager + + # Note(gnovack) - We need to disable cudagraphs for one of the two + # lora cases when cudagraph_specialize_lora is enabled. This is a + # short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/28334 + if self.compilation_config.cudagraph_specialize_lora and activate_lora: + use_cudagraphs = False + + self.drafter.dummy_run( + num_tokens, + use_cudagraphs=use_cudagraphs, + is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, + ) + + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. + # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. 
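# Illustrative sketch (not part of the diff): the Eagle drafter cudagraph decision made
# above, folded into a pure function. Strings stand in for the CUDAGraphMode enum values;
# the real code reads these flags from the compilation and speculative configs.
def drafter_uses_cudagraphs(is_capturing: bool, mode: str, enforce_eager: bool,
                            specialize_lora: bool, lora_active: bool) -> bool:
    use = (is_capturing and mode == "PIECEWISE") or (not is_capturing and mode != "NONE")
    use = use and not enforce_eager
    if specialize_lora and lora_active:  # short-term mitigation for vllm issue #28334
        use = False
    return use

# During capture only PIECEWISE captures the drafter; at runtime any non-NONE mode uses graphs.
assert drafter_uses_cudagraphs(True, "PIECEWISE", False, False, False)
assert drafter_uses_cudagraphs(False, "FULL", False, False, False)
assert not drafter_uses_cudagraphs(True, "FULL", False, False, False)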
+ if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + logit_indices_device = torch.from_numpy(logit_indices).to(self.device, non_blocking=True) + return hidden_states, hidden_states[logit_indices_device] + + def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None: + """Decode per-request prompt_embeds and additional_information for newly + scheduled requests and store them to CPU in the request state. + This version avoids hard dependency on payload classes by duck-typing.""" + try: + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if not new_reqs: + return + for nr in new_reqs: + req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) + if req_id is None: + continue + # prompt_embeds + payload_pe = getattr(nr, "prompt_embeds", None) + pe_cpu = None + if payload_pe is not None: + if isinstance(payload_pe, torch.Tensor): + pe_cpu = payload_pe.detach().to("cpu").contiguous() + else: + # Try duck-typing a payload with data/shape/dtype + data = getattr(payload_pe, "data", None) + shape = getattr(payload_pe, "shape", None) + if data is not None and shape is not None: + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(data, dtype=dt) + arr = arr.reshape(shape) + pe_cpu = torch.from_numpy(arr.copy()) + if pe_cpu is not None and req_id in self.requests: + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + else: + # Try duck-typing a payload with entries, each entry may have + # tensor_data/tensor_dtype/tensor_shape or list_data + entries = getattr(payload_info, "entries", None) + if isinstance(entries, dict): + for k, entry in entries.items(): + tensor_data = getattr(entry, "tensor_data", None) + if tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(tensor_data, dtype=dt) + arr = arr.reshape(getattr(entry, "tensor_shape", ())) + info_dict[k] = torch.from_numpy(arr.copy()) + else: + info_dict[k] = getattr(entry, "list_data", None) + if info_dict and req_id in self.requests: + setattr(self.requests[req_id], "additional_information_cpu", info_dict) + except Exception as e: + logger.error(f"Error decoding prompt_embeds / additional_information: {e}") + + def _gather_runtime_additional_information(self) -> list[dict]: + """Gather per-request additional_information stored in request state in batch order.""" + per_req_runtime_info = [] + for req_id in self.input_batch.req_ids: + req_state = self.requests.get(req_id) + info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + if info and isinstance(info, dict): + per_req_runtime_info.append(info) + if "thinker_reply_part_per_request" in info: + q = info["thinker_reply_part_per_request"] + if hasattr(q, "shape"): + logger.debug(f"[OMNI] req={req_id} has thinker_reply_part_per_request queue shape: {q.shape}") + else: + per_req_runtime_info.append({}) + return per_req_runtime_info + + def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[int, int]]: + """Compute (start, end) token spans for each request within the flattened step sequence.""" + req_token_spans: list[tuple[int, int]] = [] + for req_index in range(len(self.input_batch.req_ids)): + start_offset = 
int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + req_token_spans.append((start_offset, start_offset + sched_tokens)) + return req_token_spans + + def _build_model_kwargs_extra(self) -> dict: + """Build extra keyword arguments passed to the model for this step, including: + - runtime_additional_information: per-request additional information stored in request state + """ + model_kwargs_extra: dict[str, object] = {} + try: + model_kwargs_extra["runtime_additional_information"] = self._gather_runtime_additional_information() + except Exception as e: + logger.error(f"[OMNI DEBUG] Error building model_kwargs_extra: {e}") + import traceback + + traceback.print_exc() + return model_kwargs_extra + + def _process_additional_information_updates( + self, + hidden_states: torch.Tensor, + multimodal_outputs: object, + num_scheduled_tokens_np: np.ndarray, + scheduler_output: "SchedulerOutput", + ) -> None: + """Process model-provided per-request additional_information updates and merge into request state.""" + try: + # execute the custom postprocess function + # TODO(Peiqi): do we have a more elegant way to do this? + if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: + for req_index, req_id in enumerate(self.input_batch.req_ids): + if self.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + # only consider to store data into update dict. + hidden_states_slice = hidden_states[s:e] + update_dict = self.model.postprocess(hidden_states_slice, **req_infos) + self._merge_additional_information_update(req_id, update_dict) + except Exception as e: + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} " + f"additional information update: {e}, with the multimodal_outputs " + f"as {multimodal_outputs}" + ) + import traceback + + traceback.print_exc() + + def _collect_additional_information_for_prefill( + self, + num_scheduled_tokens_np: np.ndarray, + ) -> dict[str, dict]: + """Overlay per-request prompt_embeds for the prefill portion and collect + additional_information slices for this step. 
Returns a map req_id -> dict.""" + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + if overlay_len > 0 and pe_cpu is not None: + src = pe_cpu[num_computed_tokens : num_computed_tokens + overlay_len].to( + dtype=self.dtype, device=self.device, non_blocking=True + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src) + + def _get_additional_information(self, scheduler_output: "SchedulerOutput", req_id: str) -> dict: + req_infos = None + req_state = self.requests.get(req_id) + additional_information_cpu = getattr(req_state, "additional_information_cpu", None) + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id == req_id: + payload_info = getattr(new_req, "additional_information", None) + if payload_info is not None: + return payload_info + + if hasattr(scheduler_output.scheduled_cached_reqs, "additional_information"): + cached_infos = getattr(scheduler_output.scheduled_cached_reqs, "additional_information", {}) + if isinstance(cached_infos, dict) and req_id in cached_infos: + req_infos = cached_infos[req_id] + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None or req_infos.get("last_talker_hidden", None) is None: + if req_infos is None: + additional_information_cpu.pop("thinker_embeddings", None) + req_infos = additional_information_cpu + else: + req_infos["last_talker_hidden"] = additional_information_cpu.get("last_talker_hidden", None) + req_infos["num_processed_thinker_tokens"] = additional_information_cpu.get( + "num_processed_thinker_tokens", 0 + ) + if not isinstance(req_infos, dict): + req_infos = None + + if req_infos is None: + logger.warning(f"No additional_information found for req_id: {req_id}") + + return req_infos + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + num_input_tokens: int, + intermediate_tensors: IntermediateTensors | None = None, + ): + """Align with v0.14.0 preprocess and omni's additional information handling.""" + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + is_first_rank = get_pp_group().is_first_rank + is_encoder_decoder = self.model_config.is_encoder_decoder + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + ec_connector_output = None + + if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder: + # Run the multimodal encoder if any. + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) + + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. 
+ inputs_embeds_scheduled = self.model.embed_input_ids( + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) + + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) + model_kwargs = { + **self._init_model_kwargs(), + **self._extract_mm_kwargs(scheduler_output), + } + elif self.enable_prompt_embeds and is_first_rank: + # Get the input embeddings for the tokens that are not input embeds, + # then put them into the appropriate positions. + # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the CUDA graph all the time. The v0 + # engine avoids this by "double compiling" the CUDA graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the CUDA graph will be more performant (like in the else case + # below). + token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens].nonzero(as_tuple=False).squeeze(1) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids.gpu[token_ids_idx] + tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs() + input_ids = self.input_ids.gpu[:num_input_tokens] + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids.gpu[:num_input_tokens] + inputs_embeds = None + model_kwargs = self._init_model_kwargs() + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_input_tokens] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_input_tokens] + else: + positions = self.positions.gpu[:num_input_tokens] + + if is_first_rank: + intermediate_tensors = None + else: + assert intermediate_tensors is not None + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True + ) + + if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + # Run the encoder, just like we do with other multimodal inputs. + # For an encoder-decoder model, our processing here is a bit + # simpler, because the outputs are just passed to the decoder. + # We are not doing any prompt replacement. We also will only + # ever have a single encoder input. + encoder_outputs = self._execute_mm_encoder(scheduler_output) + model_kwargs.update({"encoder_outputs": encoder_outputs}) + + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + self._omni_num_scheduled_tokens_np = num_scheduled_tokens_np + + # Note: only prefill need collect additional_information for now. + # Decode don't need per_req_additional_information anymore. 
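# Illustrative sketch (not part of the diff): the overlay arithmetic used by
# _collect_additional_information_for_prefill(), which is invoked just below. For each
# request, only the portion of this step's scheduled tokens that still falls inside the
# prompt is overwritten with the stored prompt_embeds; the write starts at the request's
# offset in the flattened batch (query_start_loc[i]). Names here are stand-ins.
def prompt_overlay_span(prompt_len: int, num_computed_tokens: int,
                        sched_tokens: int, start_offset: int) -> tuple[int, int]:
    prompt_remaining = max(0, prompt_len - num_computed_tokens)
    overlay_len = min(sched_tokens, prompt_remaining)
    return start_offset, start_offset + overlay_len

assert prompt_overlay_span(100, 0, 64, 0) == (0, 64)     # first prefill chunk: all prompt tokens
assert prompt_overlay_span(100, 96, 16, 64) == (64, 68)  # last chunk: only 4 prompt tokens remain
assert prompt_overlay_span(100, 100, 1, 80) == (80, 80)  # decode step: nothing to overlay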
+ if inputs_embeds is not None: + # Prefill: overlay prompt_embeds and collect additional_information + self._collect_additional_information_for_prefill(num_scheduled_tokens_np) + + if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: + # Overlay custom prompt_embeds per request for the prompt portion; + # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] + for req_index, req_id in enumerate(self.input_batch.req_ids): + # Try to get additional_information from multiple sources + if self.vllm_config.model_config.async_chunk: + req_infos = self._get_additional_information(scheduler_output, req_id) + else: + req_state = self.requests.get(req_id) + req_infos = ( + getattr(req_state, "additional_information_cpu", None) if req_state is not None else None + ) + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + span_len = int(e) - int(s) + + # call the custom process function + req_input_ids, req_embeds, update_dict = self.model.preprocess( + input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + ) + if hasattr(self.model, "talker_mtp") and span_len == 1: + last_talker_hidden, text_step = update_dict.pop("mtp_inputs") + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) + + # TODO(Peiqi): the merge stage could move out from the critical path + self._merge_additional_information_update(req_id, update_dict) + + # update the inputs_embeds and input_ids + seg_len = min(span_len, req_embeds.shape[0]) + inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] + if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: + input_ids[s : s + seg_len] = req_input_ids + + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) + + return ( + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ec_connector_output, + ) + + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + num_tokens_padded = batch_desc.num_tokens + req_input_ids = self.talker_mtp_input_ids.gpu[:num_tokens_padded] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded] + last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded] + text_step = self.text_step.gpu[:num_tokens_padded] + with set_forward_context( + None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = 
self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + self._merge_additional_information_update(req_id, update_dict) + + def _model_forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], + ): + """Inject omni-specific kwargs into forward and cache model output""" + model_kwargs_extra = self._build_model_kwargs_extra() + + runtime_info = model_kwargs_extra.get("runtime_additional_information", []) + if runtime_info: + for i, info in enumerate(runtime_info): + if info: + logger.debug(f"[OMNI] req[{i}] runtime_additional_information keys: {list(info.keys())}") + + model_output = super()._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + **model_kwargs_extra, + ) + if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"): + model_output = self.model.make_omni_output(model_output, **model_kwargs_extra) + # Cache model output so later sample_tokens can consume multimodal results. + self._omni_last_model_output = model_output + return model_output + + def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + req_state = self.requests.get(req_id) + if req_state is None: + return + existing = getattr(req_state, "additional_information_cpu", {}) + if not isinstance(existing, dict): + existing = {} + merged = dict(existing) + for k, v in upd.items(): + if isinstance(v, torch.Tensor): + merged[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + merged[k] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v + ] + else: + merged[k] = v + setattr(req_state, "additional_information_cpu", merged) diff --git a/vllm_omni/worker/mixins.py b/vllm_omni/worker/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..5b25b8362b91d009c05509461a74cb9ba087f93c --- /dev/null +++ b/vllm_omni/worker/mixins.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import Any + + +class OmniWorkerMixin: + """Mixin to ensure Omni plugins are loaded in worker processes.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins()
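# Illustrative sketch (not part of the diff): the merge rule implemented by
# _merge_additional_information_update() in the runner above, shown on a plain dict.
# Tensor values (and tensors inside lists) are detached and moved to CPU before being
# stored; everything else is kept as-is.
import torch

def merge_additional_information(existing: dict, update: dict) -> dict:
    merged = dict(existing) if isinstance(existing, dict) else {}
    for key, value in update.items():
        if isinstance(value, torch.Tensor):
            merged[key] = value.detach().to("cpu").contiguous()
        elif isinstance(value, list):
            merged[key] = [v.detach().to("cpu").contiguous() if isinstance(v, torch.Tensor) else v for v in value]
        else:
            merged[key] = value
    return merged

state = {"num_processed_thinker_tokens": 0}
state = merge_additional_information(state, {"last_talker_hidden": torch.randn(1, 4)})
assert set(state) == {"num_processed_thinker_tokens", "last_talker_hidden"}
assert state["last_talker_hidden"].device.type == "cpu"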
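# Illustrative sketch (not part of the diff): how a cooperative mixin like OmniWorkerMixin
# is typically combined with a worker class. The mixin precedes the base class in the MRO,
# so its __init__ wraps the base initialization and can load the Omni plugins as part of
# worker construction. "BaseWorker" and the flag below are hypothetical stand-ins.
class _PluginLoaderMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.plugins_loaded = True  # stand-in for load_omni_general_plugins()

class BaseWorker:
    def __init__(self, rank: int = 0):
        self.rank = rank

class OmniWorkerExample(_PluginLoaderMixin, BaseWorker):
    pass

worker = OmniWorkerExample(rank=3)  # BaseWorker.__init__ runs, then the plugin hook fires
assert worker.rank == 3 and worker.plugins_loaded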