diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 7045d8810493e5c79d670a11401b99cf16268a2e..bbed80ebe84762412eb49b286d9007223e060085 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -8,7 +8,8 @@ template = """

Links for vLLM

- {wheel}
+ {x86_wheel}
+ {arm_wheel}
""" @@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel) with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") + # sync the abi tag with .buildkite/scripts/upload-wheels.sh + if "x86_64" in filename: + x86_wheel = filename + arm_wheel = filename.replace("x86_64", "aarch64").replace( + "manylinux1", "manylinux2014" + ) + elif "aarch64" in filename: + x86_wheel = filename.replace("aarch64", "x86_64").replace( + "manylinux2014", "manylinux1" + ) + arm_wheel = filename + else: + raise ValueError(f"Unsupported wheel: {filename}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + template.format( + x86_wheel=x86_wheel, + x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), + arm_wheel=arm_wheel, + arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), + ) ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml deleted file mode 100644 index 56ec933c9cc0e5e1fc8041db7f485fa272575d20..0000000000000000000000000000000000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 -model_name: "HandH1998/QQQ-Llama-3-8b-g128" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.419 - - name: "exact_match,flexible-extract" - value: 0.416 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 27a1a9a82bd352623c44728e4480ee47209bd9f0..37eeac85c933b8a5a077364d0566772d2c592208 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index a67fc89d54e604b254e1d0c6fdbd5e6fc86ef939..897f84d1e360de11ceb10d77baf0ff9f8453cdfd 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index b98d42aa7b822481d6b03448a4ed3b4ffe49c891..792f355c47a5178801b2624f1a9e06c69707f0ce 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index b39f9899a8f284d1d6e1fb8f91800954f2c5fbc5..e6f5c8b60f459a1d25d45493807e407ffca5bab3 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -141,7 +141,7 @@ When run, benchmark script generates results under `benchmark/results` folder, a `compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT. If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead. -Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps. +Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output length, max concurrency and qps. `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio | diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d383e3e72c9f1171759ddf581661d8d8..37e2980eea974a5a6bf35b92aabf1c56d2f75819 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. - Hardware - 8x Nvidia A100 GPUs diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 12c4ba6aa69a60be66396fdb63265f79675bed4f..50431d0cd4c5e69362f4047918ec42edea497ae9 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -3,44 +3,129 @@ import argparse import json import os +from importlib import util import pandas as pd +plotly_found = util.find_spec("plotly.express") is not None + def compare_data_columns( files, name_column, data_column, info_cols, drop_column, debug=False ): - print("\ncompare_data_column: " + data_column) + """ + Align concatenation by keys derived from info_cols instead of row order. + - Pick one canonical key list: subset of info_cols present in ALL files. + - For each file: set index to those keys, aggregate duplicates + - (mean for metric, first for names). + - Concat along axis=1 (indexes align), then reset_index so callers can + - group by columns. + - If --debug, add a _name column per file. + """ + print("\ncompare_data_column:", data_column) + frames = [] raw_data_cols = [] compare_frames = [] + + # 1) choose a canonical key list from info_cols that exists in ALL files + cols_per_file = [] + for f in files: + try: + df_tmp = pd.read_json(f, orient="records") + except Exception as err: + raise ValueError(f"Failed to read {f}") from err + cols_per_file.append(set(df_tmp.columns)) + + key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)] + if not key_cols: + # soft fallback: use any info_cols present in the first file + key_cols = [c for c in info_cols if c in list(cols_per_file[0])] + if not key_cols: + raise ValueError( + "No common key columns found from info_cols across the input files." + ) + + # 2) build a single "meta" block (keys as columns) once, aligned by the key index + meta_added = False + for file in files: - data_df = pd.read_json(file) - serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) - # Show all info columns in the first couple columns - if not frames: - for col in info_cols: - if col not in serving_df.columns: - print(f"Skipping missing column: {col}") - continue - frames.append(serving_df[col]) - # only show test name under debug mode - if debug is True: - serving_df = serving_df.rename(columns={name_column: file + "_name"}) - frames.append(serving_df[file + "_name"]) - - file = "/".join(file.split("/")[:-1]) - serving_df = serving_df.rename(columns={data_column: file}) - frames.append(serving_df[file]) - raw_data_cols.append(file) - compare_frames.append(serving_df[file]) + df = pd.read_json(file, orient="records") + + # Keep rows that actually have the compared metric (same as original behavior) + if drop_column in df.columns: + df = df.dropna(subset=[drop_column], ignore_index=True) + + # Stabilize numeric key columns (harmless if missing) + for c in ( + "Input Len", + "Output Len", + "TP Size", + "PP Size", + "# of max concurrency.", + "qps", + ): + if c in df.columns: + df[c] = pd.to_numeric(df[c], errors="coerce") + + # Ensure all key columns exist + for c in key_cols: + if c not in df.columns: + df[c] = pd.NA + + # Set index = key_cols and aggregate duplicates → unique MultiIndex + df_idx = df.set_index(key_cols, drop=False) + + # meta (key columns), unique per key + meta = df_idx[key_cols] + if not meta.index.is_unique: + meta = meta.groupby(level=key_cols, dropna=False).first() + + # metric series for this file, aggregated to one row per key + file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file) + s = df_idx[data_column] + if not s.index.is_unique: + s = s.groupby(level=key_cols, dropna=False).mean() + s.name = file_label # column label like original + + # add meta once (from first file) so keys are the leftmost columns + if not meta_added: + frames.append(meta) + meta_added = True + + # (NEW) debug: aligned test-name column per file + if debug and name_column in df_idx.columns: + name_s = df_idx[name_column] + if not name_s.index.is_unique: + name_s = name_s.groupby(level=key_cols, dropna=False).first() + name_s.name = f"{file_label}_name" + frames.append(name_s) + + frames.append(s) + raw_data_cols.append(file_label) + compare_frames.append(s) + + # Generalize ratio: for any file N>=2, add ratio (fileN / file1) if len(compare_frames) >= 2: - # Compare numbers among two files - ratio_df = compare_frames[1] / compare_frames[0] - frames.append(ratio_df) - compare_frames.pop(1) + base = compare_frames[0] + current = compare_frames[-1] + ratio = current / base + ratio = ratio.mask(base == 0) # avoid inf when baseline is 0 + ratio.name = f"Ratio 1 vs {len(compare_frames)}" + frames.append(ratio) + # 4) concat on columns with aligned MultiIndex; + # then reset_index to return keys as columns concat_df = pd.concat(frames, axis=1) + concat_df = concat_df.reset_index(drop=True).reset_index() + if "index" in concat_df.columns: + concat_df = concat_df.drop(columns=["index"]) + + # Ensure key/info columns appear first (in your info_cols order) + front = [c for c in info_cols if c in concat_df.columns] + rest = [c for c in concat_df.columns if c not in front] + concat_df = concat_df[front + rest] + print(raw_data_cols) return concat_df, raw_data_cols @@ -67,6 +152,15 @@ def split_json_by_tp_pp( df = pd.DataFrame(data) + # Keep only "serving" tests + name_col = next( + (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None + ) + if name_col: + df = df[ + df[name_col].astype(str).str.contains(r"serving", case=False, na=False) + ].copy() + # Handle alias column names rename_map = { "tp_size": "TP Size", @@ -181,7 +275,6 @@ if __name__ == "__main__": f"Expected subset: {filtered_info_cols}, " f"but DataFrame has: {list(output_df.columns)}" ) - output_df_sorted = output_df.sort_values(by=existing_group_cols) output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) for name, group in output_groups: @@ -189,8 +282,7 @@ if __name__ == "__main__": text_file.write(html_msgs_for_data_cols[i]) text_file.write(html) - if plot is True: - import pandas as pd + if plot and plotly_found: import plotly.express as px df = group[raw_data_cols] diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh index 06d7b5ed484da2c9e9747316a2460bb67ce6c407..a00de940cbbb81bff911996a7ec56d0c83c8470d 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -382,7 +382,7 @@ run_genai_perf_tests() { client_command="genai-perf profile \ -m $model \ --service-kind openai \ - --backend vllm \ + --backend "$backend" \ --endpoint-type chat \ --streaming \ --url localhost:$port \ diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 85d3e5638742180ce151769bd462627241ead66b..92a1bcada3879bbd5299b7359aa94e4c6648142f 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -7,7 +7,7 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -27,7 +27,12 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build CUDA 12.6 wheel" + key: block-build-cu126-wheel + depends_on: ~ + - label: "Build wheel - CUDA 12.6" + depends_on: block-build-cu126-wheel id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -57,23 +62,45 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build release image" + - label: "Build release image (x86)" depends_on: ~ - key: block-release-image-build - - - label: "Build release image" - depends_on: block-release-image-build - id: build-release-image + id: build-release-image-x86 agents: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + # re-tag to default image tag and push, just in case arm64 build fails + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Build release image (arm64)" + depends_on: ~ + id: build-release-image-arm64 + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + + # Add job to create multi-arch manifest + - label: "Create multi-arch manifest" + depends_on: + - build-release-image-x86 + - build-release-image-arm64 + id: create-multi-arch-manifest + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker manifest create public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 --amend" + - "docker manifest push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Annotate release workflow" depends_on: - - build-release-image + - create-multi-arch-manifest - build-wheel-cuda-12-8 - build-wheel-cuda-12-6 - build-wheel-cuda-11-8 diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index df0bae0c9cbff8f93e3c5a238477b935eaaca9f9..c395011a244853bb89f8ad4102f4753b8726ddea 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -164,7 +164,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then --ignore=entrypoints/llm/test_chat.py \ --ignore=entrypoints/llm/test_accuracy.py \ --ignore=entrypoints/llm/test_init.py \ - --ignore=entrypoints/llm/test_generate_multiple_loras.py \ --ignore=entrypoints/llm/test_prompt_validation.py "} fi diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index bbce7a25f97d61e95e2f496aac4dbc28dbb6c2f8..0f734763f13fda95b682eb95a00b2c5b8792895f 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -46,21 +46,26 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run kernel tests + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -x -v -s tests/kernels/test_onednn.py" + # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e # Note: disable until supports V1 - # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model # Note: disable Bart until supports V1 - pytest -v -s tests/models/language/generation -m cpu_model \ + pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ + VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - pytest -v -s tests/models/language/pooling -m cpu_model - pytest -v -s tests/models/multimodal/generation \ + pytest -x -v -s tests/models/language/pooling -m cpu_model + pytest -x -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ --ignore=tests/models/multimodal/generation/test_pixtral.py \ -m cpu_model" @@ -68,35 +73,51 @@ function cpu_tests() { # Run compressed-tensor test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" # Note: disable it until supports V1 # Run AWQ test - docker exec cpu-test-"$NUMA_NODE" bash -c " - set -e - VLLM_USE_V1=0 pytest -s -v \ - tests/quantization/test_ipex_quant.py" + # docker exec cpu-test-"$NUMA_NODE" bash -c " + # set -e + # VLLM_USE_V1=0 pytest -x -s -v \ + # tests/quantization/test_ipex_quant.py" # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/lora/test_qwen2vl.py" - # online serving + # online serving: tp+pp docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + server_pid=$! + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + vllm bench serve \ + --backend vllm \ + --dataset-name random \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --num-prompts 20 \ + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' + + # online serving: tp+dp + docker exec cpu-test-"$NUMA_NODE" bash -c ' + set -e + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & + server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions' + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' } # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index b571618f48c2bf2ed901285422aa2fa2aee6d558..1073a4ee30afa19b5b94f7fefea9a36b1f6d22d3 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index d55a786e41e8b63a0b230d64a67d9a7b17466aa8..505664f3aecd037f45634c5b3ea5c42785201076 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 6256677481ae4f669b011d32c467795f791a0ca3..73f3e63fbf5f6dac30e03a6aad910cb041620eec 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -23,12 +23,15 @@ docker run \ --device /dev/dri \ -v /dev/dri/by-path:/dev/dri/by-path \ --entrypoint="" \ + -e "HF_TOKEN=${HF_TOKEN}" \ + -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \ --name "${container_name}" \ "${image_name}" \ - sh -c ' - VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m - VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 + bash -c ' + set -e + echo $ZE_AFFINITY_MASK VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests @@ -37,8 +40,8 @@ docker run \ pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py + pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s v1/test_serial_utils.py pytest -v -s v1/test_utils.py pytest -v -s v1/test_metrics_reader.py diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh index 209d9c4341cdd83033a92fb878ceb8e6b13d5298..740d81fb39bb0be2c73d90fedfc8701766ca393c 100755 --- a/.buildkite/scripts/tpu/cleanup_docker.sh +++ b/.buildkite/scripts/tpu/cleanup_docker.sh @@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune --force --filter "until=72h" --all + docker volume prune -f && docker system prune --force --filter "until=24h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 037897e53dbef42d43fc7656c3d6caf8c18fbe9a..745f285c008ad846051f85e88972e6525869739c 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -14,8 +14,19 @@ fi # Get the single wheel file wheel="${wheel_files[0]}" -# Rename 'linux' to 'manylinux1' in the wheel filename -new_wheel="${wheel/linux/manylinux1}" +# Detect architecture and rename 'linux' to appropriate manylinux version +arch=$(uname -m) +if [[ $arch == "x86_64" ]]; then + manylinux_version="manylinux1" +elif [[ $arch == "aarch64" ]]; then + manylinux_version="manylinux2014" +else + echo "Warning: Unknown architecture $arch, using manylinux1 as default" + manylinux_version="manylinux1" +fi + +# Rename 'linux' to the appropriate manylinux version in the wheel filename +new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4fc8857854927938e833032f8d4e9188344a1018..55349e0ac9321df1dd343e77d0a7ee49bf80991d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -88,15 +88,6 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Chunked Prefill Test - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/basic_correctness/test_chunked_prefill - commands: - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - label: Core Test # 10min mirror_hardwares: [amdexperimental] fast_check: true @@ -118,10 +109,9 @@ steps: - tests/entrypoints/offline_mode commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Entrypoints Test (API Server) # 40min @@ -135,7 +125,8 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py - pytest -v -s entrypoints/test_chat_utils.py - label: Distributed Tests (4 GPUs) # 10min @@ -242,16 +233,34 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: V1 Test +- label: V1 Test e2e + engine mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 commands: - # split the test to avoid interference - - pytest -v -s v1/core + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - pytest -v -s v1/e2e - pytest -v -s v1/engine + +- label: V1 Test entrypoints + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: - pytest -v -s v1/entrypoints + +- label: V1 Test others + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # split the test to avoid interference + - pytest -v -s v1/core + - pytest -v -s v1/executor - pytest -v -s v1/sample - pytest -v -s v1/logits_processors - pytest -v -s v1/worker @@ -263,9 +272,6 @@ steps: - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_metrics_reader.py - # TODO: accuracy does not match, whether setting - # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - - pytest -v -s v1/e2e # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine @@ -295,15 +301,6 @@ steps: - python3 offline_inference/basic/score.py - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 -- label: Prefix Caching Test # 9min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/prefix_caching - commands: - - pytest -v -s prefix_caching - - - label: Platform Tests (CUDA) mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -328,7 +325,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py parallelism: 4 - label: PyTorch Compilation Unit Tests @@ -345,6 +342,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] @@ -358,6 +356,7 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_full_cudagraph.py + - pytest -v -s compile/piecewise/test_multiple_graphs.py - label: PyTorch Fullgraph Test # 18min mirror_hardwares: [amdexperimental] @@ -404,6 +403,7 @@ steps: - csrc/moe/ - tests/kernels/moe - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ commands: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 @@ -462,19 +462,17 @@ steps: - tests/quantization commands: # temporary install here since we need nightly, will move to requirements/test.in - # after torchao 0.12 release - - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + # after torchao 0.12 release, and pin a working version of torchao nightly here + - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness mirror_hardwares: [amdexperimental] @@ -562,6 +560,14 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' +- label: Multi-Modal Processor Test + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing + - label: Multi-Modal Models Test (Standard) mirror_hardwares: [amdexperimental] torch_nightly: true @@ -571,9 +577,7 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - - pytest -v -s models/multimodal/processing - - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model - - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn" + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 @@ -584,7 +588,7 @@ steps: - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing - label: Multi-Modal Models Test (Extended) 2 mirror_hardwares: [amdexperimental] @@ -647,8 +651,10 @@ steps: - vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -660,11 +666,17 @@ steps: # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern + - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py ##### 1 GPU test ##### ##### multi gpus test ##### @@ -757,6 +769,11 @@ steps: - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model @@ -793,13 +810,14 @@ steps: # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - - pytest -v -s -x lora/test_multi_loras_with_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py - label: Weight Loading Multiple GPU Test # 33min mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" - num_gpus: 2 + num_gpus: 2 + optional: true source_file_dependencies: - vllm/ - tests/weight_loading @@ -847,3 +865,10 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: Qwen MoE EP Test # optional + gpu: h200 + optional: true + num_gpus: 2 + commands: + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b0dd5e99d4c7278b8af61c21f2fbd983b3aacab1..c087fd555c661e5d23ce16a19cda3c75900accbf 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,6 +10,7 @@ /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 +/vllm/model_executor/layers/mamba @tdoublep /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee @@ -25,11 +26,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/attention/backends/triton_attn.py @tdoublep # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo /tests/async_engine @njhill @robertgshaw2-redhat @simon-mo -/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao @@ -44,6 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/structured_output @mgoin @russellb @aarnphm /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee +/tests/models/language/generation/test_hybrid.py @tdoublep # Docs /docs @hmellor @@ -72,3 +74,15 @@ mkdocs.yaml @hmellor /vllm/model_executor/models/pixtral*.py @patrickvonplaten /vllm/transformers_utils/configs/mistral.py @patrickvonplaten /vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten + +# Kernels +/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep +/vllm/attention/ops/triton_unified_attention.py @tdoublep + +# ROCm related: specify owner with write access to notify AMD folks for careful code review +/docker/Dockerfile.rocm* @gshtras +/vllm/v1/attention/backends/rocm*.py @gshtras +/vllm/v1/attention/backends/mla/rocm*.py @gshtras +/vllm/attention/ops/rocm*.py @gshtras +/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1b30c1292df85ddcf4317994b074d16f142fbc77..8043df65d5585a946779b6ef86a2e4cb7b98effb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT ## Test Result -## (Optional) Documentation Update - ---
Essential Elements of an Effective PR Description Checklist @@ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT - [ ] The test plan, such as providing test command. - [ ] The test results, such as pasting the results comparison before and after, or e2e results - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. +- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/scale-config.yml b/.github/scale-config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c41a3ee3eb196049de99bdeacb8e2051ec0009ca --- /dev/null +++ b/.github/scale-config.yml @@ -0,0 +1,21 @@ +# scale-config.yml: +# Powers what instance types are available for GHA auto-scaled +# runners. Runners listed here will be available as self hosted +# runners, configuration is directly pulled from the main branch. +# runner_types: +# runner_label: +# instance_type: m4.large +# os: linux +# # min_available defaults to the global cfg in the ALI Terraform +# min_available: undefined +# # when max_available value is not defined, no max runners is enforced +# max_available: undefined +# disk_size: 50 +# is_ephemeral: true + +runner_types: + linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: true + os: linux diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml new file mode 100644 index 0000000000000000000000000000000000000000..e0ab3872d8fa377f383b5677b3430cc636335d13 --- /dev/null +++ b/.github/workflows/issue_autolabel.yml @@ -0,0 +1,309 @@ +name: Label issues based on keywords +on: + issues: + types: [opened, edited, reopened] +permissions: + issues: write # needed so the workflow can add labels + contents: read +concurrency: + group: issue-labeler-${{ github.event.issue.number }} + cancel-in-progress: true +jobs: + add-labels: + runs-on: ubuntu-latest + steps: + - name: Label issues based on keywords + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + // Configuration: Add new labels and keywords here + const labelConfig = { + rocm: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "composable kernel", + searchIn: "both" + }, + { + term: "rccl", + searchIn: "body" // only search in body + }, + { + term: "migraphx", + searchIn: "title" // only search in title + }, + { + term: "hipgraph", + searchIn: "both" + }, + { + term: "ROCm System Management Interface", + searchIn: "body" + }, + ], + + // Substring search - matches anywhere in text (partial matches) + substrings: [ + { + term: "VLLM_ROCM_", + searchIn: "both" + }, + { + term: "aiter", + searchIn: "title" + }, + { + term: "rocm", + searchIn: "title" + }, + { + term: "amd", + searchIn: "title" + }, + { + term: "hip-", + searchIn: "both" + }, + { + term: "gfx", + searchIn: "both" + }, + { + term: "cdna", + searchIn: "both" + }, + { + term: "rdna", + searchIn: "both" + }, + { + term: "torch_hip", + searchIn: "body" // only in body + }, + { + term: "_hip", + searchIn: "both" + }, + { + term: "hip_", + searchIn: "both" + }, + + // ROCm tools and libraries + { + term: "hipify", + searchIn: "both" + }, + ], + + // Regex patterns - for complex pattern matching + regexPatterns: [ + { + pattern: "\\bmi\\d{3}[a-z]*\\b", + description: "AMD GPU names (mi + 3 digits + optional letters)", + flags: "gi", + searchIn: "both" // "title", "body", or "both" + } + ], + }, + }; + + // Helper function to create regex based on search type + function createSearchRegex(term, type) { + // Escape special regex characters in the term + const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + + switch (type) { + case 'keyword': + // Word boundary search - matches whole words only + return new RegExp(`\\b${escapedTerm}\\b`, "gi"); + case 'substring': + // Substring search - matches anywhere in the text + return new RegExp(escapedTerm, "gi"); + default: + throw new Error(`Unknown search type: ${type}`); + } + } + + // Helper function to find matching terms in text with line information + function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { + const matches = []; + const lines = text.split('\n'); + + for (const termConfig of searchTerms) { + let regex; + let term, searchIn, pattern, description, flags; + + // Handle different input formats (string or object) + if (typeof termConfig === 'string') { + term = termConfig; + searchIn = 'both'; // default + } else { + term = termConfig.term; + searchIn = termConfig.searchIn || 'both'; + pattern = termConfig.pattern; + description = termConfig.description; + flags = termConfig.flags; + } + + // Skip if this term shouldn't be searched in the current location + if (searchIn !== 'both' && searchIn !== searchLocation) { + continue; + } + + // Create appropriate regex + if (searchType === 'regex') { + regex = new RegExp(pattern, flags || "gi"); + } else { + regex = createSearchRegex(term, searchType); + } + + const termMatches = []; + + // Check each line for matches + lines.forEach((line, lineIndex) => { + const lineMatches = line.match(regex); + if (lineMatches) { + lineMatches.forEach(match => { + termMatches.push({ + match: match, + lineNumber: lineIndex + 1, + lineContent: line.trim(), + searchType: searchType, + searchLocation: searchLocation, + originalTerm: term || pattern, + description: description, + // Show context around the match in the line + context: line.length > 100 ? + line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), + line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' + : line.trim() + }); + }); + } + }); + + if (termMatches.length > 0) { + matches.push({ + term: term || (description || pattern), + searchType: searchType, + searchLocation: searchLocation, + searchIn: searchIn, + pattern: pattern, + matches: termMatches, + count: termMatches.length + }); + } + } + + return matches; + } + + // Helper function to check if label should be added + async function processLabel(labelName, config) { + const body = context.payload.issue.body || ""; + const title = context.payload.issue.title || ""; + + core.notice(`Processing label: ${labelName}`); + core.notice(`Issue Title: "${title}"`); + core.notice(`Issue Body length: ${body.length} characters`); + + let shouldAddLabel = false; + let allMatches = []; + let reason = ''; + + const keywords = config.keywords || []; + const substrings = config.substrings || []; + const regexPatterns = config.regexPatterns || []; + + core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); + + // Search in title + if (title.trim()) { + core.notice(`Searching in title: "${title}"`); + + const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); + const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); + const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); + + allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); + } + + // Search in body + if (body.trim()) { + core.notice(`Searching in body (${body.length} characters)`); + + const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); + const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); + const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); + + allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); + } + + if (allMatches.length > 0) { + core.notice(`Found ${allMatches.length} matching term(s):`); + + for (const termMatch of allMatches) { + const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; + const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; + + if (termMatch.searchType === 'regex') { + core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } else { + core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } + + // Show details for each match + termMatch.matches.forEach((match, index) => { + core.notice(` ${index + 1}. Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); + if (match.description) { + core.notice(` Description: ${match.description}`); + } + core.notice(` Context: ${match.context}`); + if (match.lineContent !== match.context) { + core.notice(` Full line: ${match.lineContent}`); + } + }); + } + + shouldAddLabel = true; + const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); + const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); + const bodyMatches = allMatches.filter(t => t.searchLocation === 'body').reduce((sum, t) => sum + t.count, 0); + const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); + const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); + const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); + + reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; + } + + core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); + core.notice(`Reason: ${reason || 'No matching terms found'}`); + + if (shouldAddLabel) { + const existingLabels = context.payload.issue.labels.map(l => l.name); + if (!existingLabels.includes(labelName)) { + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: [labelName], + }); + core.notice(`Label "${labelName}" added. ${reason}`); + return true; + } + core.notice(`Label "${labelName}" already present.`); + return false; + } + + core.notice(`No matching terms found for label "${labelName}".`); + return false; + } + + // Process all configured labels + const processLabels = Object.entries(labelConfig) + .map(([labelName, config]) => processLabel(labelName, config)); + const labelsAdded = await Promise.all(processLabels); + const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); + core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`); \ No newline at end of file diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml deleted file mode 100644 index 2b1086b7faf432e91231fe0eee723c1d993c3339..0000000000000000000000000000000000000000 --- a/.github/workflows/lint-and-deploy.yaml +++ /dev/null @@ -1,89 +0,0 @@ -name: Lint and Deploy Charts - -on: pull_request - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - lint-and-deploy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: Set up Helm - uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 - with: - version: v3.14.4 - - #Python is required because ct lint runs Yamale and yamllint which require Python. - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: '3.13' - - - name: Set up chart-testing - uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 - with: - version: v3.10.1 - - - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - - - name: Setup minio - run: | - docker network create vllm-net - docker run -d -p 9000:9000 --name minio --net vllm-net \ - -e "MINIO_ACCESS_KEY=minioadmin" \ - -e "MINIO_SECRET_KEY=minioadmin" \ - -v /tmp/data:/data \ - -v /tmp/config:/root/.minio \ - minio/minio server /data - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - export AWS_EC2_METADATA_DISABLED=true - mkdir opt-125m - cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. - aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket - aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - - - name: Create kind cluster - uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 - - - name: Build the Docker image vllm cpu - run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env . - - - name: Configuration of docker images, network and namespace for the kind cluster - run: | - docker pull amazon/aws-cli:2.6.4 - kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing - kind load docker-image vllm-cpu-env:latest --name chart-testing - docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" - kubectl create ns ns-vllm - - - name: Run chart-testing (install) - run: | - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - - - name: curl test - run: | - kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & - sleep 10 - CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ - --header "Content-Type: application/json" \ - --data '{ - "model": "opt-125m", - "prompt": "San Francisco is a", - "max_tokens": 7, - "temperature": 0 - }'):$CODE" - echo "$CODE" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index bfd02879965eee1fb1eead062edd21f53798e14f..0000000000000000000000000000000000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,111 +0,0 @@ -# This workflow will upload a Python Package to Release asset -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions - -name: Create Release - -on: - push: - tags: - - v* - -# Needed to create release and upload assets -permissions: - contents: write - -jobs: - release: - # Retrieve tag and create release - name: Create Release - runs-on: ubuntu-latest - outputs: - upload_url: ${{ steps.create_release.outputs.upload_url }} - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Extract branch info - shell: bash - run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - - - name: Create Release - id: create_release - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - env: - RELEASE_TAG: ${{ env.release_tag }} - with: - github-token: "${{ secrets.GITHUB_TOKEN }}" - script: | - const script = require('.github/workflows/scripts/create_release.js') - await script(github, context, core) - - # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow. - # wheel: - # name: Build Wheel - # runs-on: ${{ matrix.os }} - # needs: release - - # strategy: - # fail-fast: false - # matrix: - # os: ['ubuntu-20.04'] - # python-version: ['3.9', '3.10', '3.11', '3.12'] - # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt. - # cuda-version: ['11.8', '12.1'] - - # steps: - # - name: Checkout - # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - # - name: Setup ccache - # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - # with: - # create-symlink: true - # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - # - name: Set up Linux Env - # if: ${{ runner.os == 'Linux' }} - # run: | - # bash -x .github/workflows/scripts/env.sh - - # - name: Set up Python - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - # - name: Build wheel - # shell: bash - # env: - # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - # run: | - # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - # asset_name=${wheel_name//"linux"/"manylinux1"} - # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - - # - name: Upload Release Asset - # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # with: - # upload_url: ${{ needs.release.outputs.upload_url }} - # asset_path: ./dist/${{ env.wheel_name }} - # asset_name: ${{ env.asset_name }} - # asset_content_type: application/* - - # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested - # - name: Publish package - # uses: pypa/gh-action-pypi-publish@release/v1.8 - # with: - # repository-url: https://test.pypi.org/legacy/ - # password: ${{ secrets.PYPI_API_TOKEN }} - # skip-existing: true diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 16ae1aadb96be289c4a153dda43772e3586e84cb..1ee605dc7bb0d40e3aba622cf4f18b16c878ee08 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -12,16 +12,43 @@ jobs: uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + - '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + - 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + - 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + - 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + - '🚀' - }) + try { + // Get the PR author + const prAuthor = context.payload.pull_request.user.login; + + // Check if this is the author's first PR in this repository + // Use GitHub's search API to find all PRs by this author + const { data: searchResults } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`, + per_page: 100 + }); + + const authorPRCount = searchResults.total_count; + + console.log(`Found ${authorPRCount} PRs by ${prAuthor}`); + + // Only post comment if this is the first PR (only one PR by this author) + if (authorPRCount === 1) { + console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + + '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + + 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' + + 'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' + + 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + + 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + + 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' + + '🚀' + }); + } else { + console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`); + } + } catch (error) { + console.error('Error checking PR history or posting comment:', error); + // Don't fail the workflow, just log the error + } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 612b290e88d46590f68ad2a79c7f92ab3ff6ff69..c16bdeeecd07a88ee9b8f9fb18155060a90aee31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: ruff-format files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos - rev: v1.34.0 + rev: v1.35.5 hooks: - id: typos - repo: https://github.com/PyCQA/isort diff --git a/CMakeLists.txt b/CMakeLists.txt index e64b7134e6fe97ebc4178f47404d5195221cc907..e669fda88123d129e427c2462c69fc803edc9cfb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx906;gfx926;gfx928;gfx936") @@ -50,7 +50,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1") +set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1") # @@ -370,9 +370,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) set(MARLIN_SRCS - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") @@ -556,6 +554,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") @@ -765,6 +765,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + # if CUDA endif endif() @@ -806,7 +833,9 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/moe_fused_gate.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/README.md b/README.md index 7a0b9c86e777720cd5cc632531e2cb617ae2eb03..471ec0896beaa4efeca2980d5e8f52e9d896e92d 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ python3 setup.py install (若调试,可使用python3 setup.py develop) + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ ## 验证 -- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.10.1; +- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.10.2rc1; ## Known Issue - 无 diff --git a/README_ORIGIN.md b/README_ORIGIN.md index f5738866cb6bfa3527273c3abaadeafc580342c5..d08d556a3da711fa0f2330411801a41c5b203db1 100644 --- a/README_ORIGIN.md +++ b/README_ORIGIN.md @@ -18,14 +18,16 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 +- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). -- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/SECURITY.md b/SECURITY.md index 414669fb3712e316497349ddeaf898b6bbf17868..d6319cdb1ac27215cd0a78ed47a408867e3ef434 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma * If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. +* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications + * Substantial internal deployment leveraging the upstream vLLM project. + * Established internal security teams and comprehensive compliance measures. + * Active and consistent contributions to the upstream vLLM project. + * We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. diff --git a/benchmarks/README.md b/benchmarks/README.md index 1d715a193ea14e0880ba25580d4fa838a70b7cae..38072152b653b831375b46802fd9ce88d29b2b8a 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -32,6 +32,14 @@ become available.
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
wget http://images.cocodataset.org/zips/train2017.zip + + + ShareGPT4Video (Video) + ✅ + ✅ + + git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video + BurstGPT @@ -51,6 +59,12 @@ become available. ✅ synthetic + + RandomMultiModal (Image/Video) + 🟡 + 🚧 + synthetic + Prefix Repetition ✅ @@ -194,6 +208,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -230,6 +245,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -244,6 +260,7 @@ vllm bench serve \ ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -609,7 +626,7 @@ vllm bench serve \ --prefix-repetition-prefix-len 512 \ --prefix-repetition-suffix-len 128 \ --prefix-repetition-num-prefixes 5 \ - --prefix-repetition-output-len 128 + --prefix-repetition-output-len 128 ```
@@ -684,4 +701,102 @@ python benchmarks/benchmark_serving.py \ --endpoint /v1/chat/completion ``` +### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + +### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. +- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). + +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works: + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. + diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 1559ca2d92841cc33eb1921b319bce5c512f5b90..ba7c733be0b25bdef90e4fba703ecb4640fe1787 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -34,6 +34,7 @@ class RequestFuncInput: multi_modal_content: Optional[dict | list[dict]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -71,6 +72,9 @@ async def async_request_tgi( "inputs": request_func_input.prompt, "parameters": params, } + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len if request_func_input.ignore_eos: @@ -82,7 +86,9 @@ async def async_request_tgi( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -145,6 +151,9 @@ async def async_request_trt_llm( } if request_func_input.ignore_eos: payload["min_length"] = request_func_input.output_len + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -152,7 +161,9 @@ async def async_request_trt_llm( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( "top_p": 1.0, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -283,6 +296,8 @@ async def async_request_openai_completions( if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -395,6 +410,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -491,6 +508,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 572292a5aca46faad63f839a6fadf578424366ea..2ea4f9ccaff2b70d57a791de36dd9ec9ace69fe3 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -19,6 +19,7 @@ import logging import random from abc import ABC, abstractmethod from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO @@ -54,6 +55,7 @@ class SampleRequest: expected_output_len: int multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): @abstractmethod def sample( - self, tokenizer: PreTrainedTokenizerBase, num_requests: int + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): raise NotImplementedError("sample must be implemented in subclasses.") def maybe_oversample_requests( - self, requests: list[SampleRequest], num_requests: int + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): Args: requests (List[SampleRequest]): The current list of sampled - requests. num_requests (int): The target number of requests. + requests. + num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: ) +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, + } + + if isinstance(video, str): + video_url = ( + video if video.startswith(("http://", "file://")) else f"file://{video}" + ) + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- @@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Enforce range_ratio < 1 @@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), ) ) + return requests @@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None, ): continue - # TODO: Also support ShareGPT4Video. if image_path := entry.get("image"): mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) else: mm_content = None if enable_multimodal_chat: @@ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset): expected_output_len=new_output_len, lora_request=lora_request, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices( self.data, k=num_input_lines - num_prefix_lines @@ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): msg, add_generation_prompt=True, tokenize=False ) prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: samples.append( SampleRequest( prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 return samples @@ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), ) ) return samples @@ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in filtered_data: if len(sampled_requests) >= num_requests: @@ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -814,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) # apply template prompt = tokenizer.apply_chat_template( @@ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + **kwargs, + ): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): expected_output_len=len( tokenizer(sample["expected_output"]).input_ids ), + request_id=request_id_prefix + str(i), ) ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: import librosa @@ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] skipped = 0 + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: break @@ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index ae38caf7290b1ce2dd9858865e42695ef3ab5452..02f5f585c0c1677db7133b994bf3f090b43448a2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -375,11 +375,12 @@ async def benchmark( rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -397,6 +398,7 @@ async def benchmark( multi_modal_content=mm_content, ignore_eos=ignore_eos, extra_body=extra_body, + request_id=request_id, ) task = limited_request_func(request_func_input=request_func_input, pbar=pbar) tasks.append(asyncio.create_task(task)) @@ -665,6 +667,7 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -678,6 +681,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -690,6 +694,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -751,6 +756,7 @@ def main(args: argparse.Namespace): num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: @@ -762,10 +768,15 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, num_requests=args.num_prompts, output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path - ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -773,6 +784,7 @@ def main(args: argparse.Namespace): input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, ), } @@ -1118,6 +1130,13 @@ def create_argument_parser(): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 7dc58f48faa907c10011c07f8de1524a777e3e54..5d96bd68cf7d276c585e3daf3a9fdad59a8e6386 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -150,7 +150,6 @@ def run_vllm( end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] # output_len should be the same for all requests. output_len = requests[0].expected_output_len for request in requests: @@ -653,8 +652,8 @@ def validate_args(args): # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( - "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead" + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" ) diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..9663503e9baa0907aec3e73ba1f1db85107a2ed0 --- /dev/null +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton as vllm_triton + +assert current_platform.is_cuda(), ( + "Only support benchmarking w8a8 block fp8 kernel on CUDA device." +) + +# DeepSeek-V3 weight shapes +DEEPSEEK_V3_SHAPES = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + (18432 * 2, 7168), + (24576, 1536), + (12288, 7168), + (4096, 7168), + (7168, 2048), +] + + +def build_w8a8_block_fp8_runner(M, N, K, block_size, device): + """Build runner function for w8a8 block fp8 matmul.""" + factor_for_scale = 1e-2 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Create random FP8 tensors + A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Create scales + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) + * factor_for_scale + ) + + def run(): + return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16) + + return run + + +@vllm_triton.testing.perf_report( + vllm_triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=["torch-bf16", "w8a8-block-fp8"], + line_names=["torch-bf16", "w8a8-block-fp8"], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs W8A8 Block FP8 GEMMs", + args={}, + ) +) +def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): + M = batch_size + device = "cuda" + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + a = torch.randn((M, K), device=device, dtype=torch.bfloat16) + b = torch.randn((N, K), device=device, dtype=torch.bfloat16) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: # w8a8-block-fp8 + run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +if __name__ == "__main__": + block_size = (128, 128) + + for N, K in DEEPSEEK_V3_SHAPES: + print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}") + + print(f"TFLOP/s comparison (block_size={block_size}):") + benchmark_tflops.run( + print_data=True, + # show_plots=False, + # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}", + N=N, + K=K, + block_size=block_size, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99ae911535caa4cca17122656db06a4b..a6b42406b5cb06b793cb043c2d7f24fab827e078 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def bench_run( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def bench_run( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -150,6 +167,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def bench_run( "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 975d10f2e92ec80ce0a6425923311c8c3a7cd9a6..1b1c3b321cce44d61ec6fb3e0f97f22950956978 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 - - if bt.w_ch_s is not None: - s_ch = bt.w_ch_s.to(torch.float32) - else: - s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) - - if bt.w_tok_s is not None: - s_tok = bt.w_tok_s.to(torch.float32) - else: - s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) - - fn = lambda: ops.marlin_qqq_gemm( - a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - ) + raise NotImplementedError("QQQ is not supported anymore") return fn @@ -305,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -406,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index dbe03002fcebd41898a018851df530c8ad152f1e..a5073a0f4e331c44440f2982f87437f9b58d7b59 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -489,8 +489,10 @@ class BenchmarkWorker: ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. + block_n = block_quant_shape[0] if block_quant_shape else None + block_k = block_quant_shape[1] if block_quant_shape else None op_config = get_moe_configs( - num_experts, shard_intermediate_size // 2, dtype_str, use_nn_moe=nn_moe + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k, use_nn_moe=nn_moe ) if op_config is None: config = get_default_config( @@ -500,8 +502,8 @@ class BenchmarkWorker: hidden_size, topk, dtype_str, - is_marlin=False, - use_nn_moe=nn_moe + block_quant_shape, + use_nn_moe=nn_moe, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..0650cbf3cc18e1c533c0bb9da6ad5914e5a49979 --- /dev/null +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm, +) +from vllm.platforms import current_platform + + +def benchmark(E, T, H, G=128, runs=50): + current_platform.seed_everything(42) + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") + tokens_per_expert = torch.randint( + T // 2, T, size=(E,), dtype=torch.int32, device="cuda" + ) + + # Warmup + for _ in range(10): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + # Benchmark + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(runs): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + avg_time = (time.perf_counter() - start) / runs * 1000 + + # Calculate actual work done (only count valid tokens) + actual_tokens = tokens_per_expert.sum().item() + actual_elements = actual_tokens * H + + # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops + ops_per_element = 8 + total_ops = actual_elements * ops_per_element + gflops = total_ops / (avg_time / 1000) / 1e9 + + # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) + input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs + output_bytes = actual_tokens * H * 1 # H fp8 outputs + scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 + total_bytes = input_bytes + output_bytes + scale_bytes + memory_bw = total_bytes / (avg_time / 1000) / 1e9 + + return avg_time, gflops, memory_bw + + +configs = [ + (8, 32, 1024), + (16, 64, 2048), + (32, 128, 4096), + # DeepSeekV3 Configs + (256, 16, 7168), + (256, 32, 7168), + (256, 64, 7168), + (256, 128, 7168), + (256, 256, 7168), + (256, 512, 7168), + (256, 1024, 7168), +] + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") +print("-" * 50) + +for E, T, H in configs: + try: + time_ms, gflops, gbps = benchmark(E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") + except Exception: + print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca45b5031a235698c0a392a82f1a9ed5f..603ce5ecf0d2c60ca4de585605fbe89f9529f61f 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,149 +27,188 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + sm_scale = float(1.0 / (head_size**0.5)) - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query + + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len + + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 ) - - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) - k_scale = v_scale = 1.0 - - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - def time_fn(fn, warmup=10, trials=20): - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - times = [] - for i in range(warmup): - fn() - for i in range(trials): - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) # ms - return sum(times) / len(times), torch.std(torch.tensor(times)) - - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] + for i in range(batch_size): + seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, - use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), + use_tensor_cores=True, ) - wrapper.plan( kv_indptr, kv_indices, kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, "NONE", + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, + kv_data_type=dtype, ) + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +220,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +250,43 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbcca99023ecb575e4f45e18039b5a899..40903c6c3444f5aa3aa88e2247ad3c9c643c2cbb 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +27,100 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + max_q_len = max_kv_len = max_seq_len - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) + sm_scale = float(1.0 / (head_size**0.5)) - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN - - seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size - block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype ) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) - k_scale = v_scale = 1.0 + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() - output_trtllm = torch.empty(q.shape, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +132,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +155,76 @@ def benchmark_prefill( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_prefill(): return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, ) - def trt_prefill(): + def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +236,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +266,42 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65ecd366f51ce5d9053cc700854dcb71..98bde9d83c82d199d3eb3fb609b10fff28b99525 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,8 +11,8 @@ from datetime import datetime from typing import Any import torch -import tqdm import triton +from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( _w8a8_block_fp8_matmul, @@ -141,6 +141,7 @@ def get_weight_shapes(tp_size): # cannot TP total = [ (512 + 64, 7168), + (2112, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (7168, 16384), diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394afbdfa4f3a352a599db349a2601bea5..9a057990bda5f64deada11b0beb56c0207570de5 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ WEIGHT_SHAPES = { ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md index ae0866ae607511023596b0d192d4e311ae95d42e..7adf97bcf56228916918fa54786a03e28b4abdbe 100644 --- a/benchmarks/multi_turn/README.md +++ b/benchmarks/multi_turn/README.md @@ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re First start serving your model ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -vllm serve $MODEL_NAME --disable-log-requests +vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests ``` +The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). + ## Synthetic Multi-Turn Conversations Download the following text file (used for generation of synthetic conversations) @@ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not Then run the benchmarking script ```bash -export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ -python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \ ---num-clients 2 --max-active-conversations 6 +python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \ +--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6 ``` You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 53c3207491d188ddbd1f2bafd7b0f554aefcaa05..d23b7b6e4571dff6263aeb508ba5c8e4604ed786 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -825,9 +825,11 @@ def get_client_config( # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" + model_name = args.served_model_name if args.served_model_name else args.model + req_args = RequestArgs( chat_url=chat_url, - model=args.model, + model=model_name, stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, @@ -1247,9 +1249,19 @@ async def main() -> None: default=0, help="Seed for random number generators (default: 0)", ) + parser.add_argument( "-m", "--model", type=str, required=True, help="Path of the LLM model" ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + parser.add_argument( "-u", "--url", diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e0da46e2accaaea3dff79fd7e1acf57bd2b0cc84..52bfd82c7fcfed310bd925ecdbd59ba1f1140f40 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,6 +1,7 @@ include(FetchContent) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -182,17 +183,17 @@ endif() # # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) # Flag to enable ACL kernels for AARCH64 platforms -if ( VLLM_BUILD_ACL STREQUAL "ON") +if (VLLM_BUILD_ACL STREQUAL "ON") set(USE_ACL ON) else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.8.1 + GIT_TAG v3.9 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) @@ -204,7 +205,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") - endif() + endif() set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") @@ -217,38 +218,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) -elseif(POWER10_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.2 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src ) - - set(ONEDNN_LIBRARY_TYPE "STATIC") - set(ONEDNN_BUILD_DOC "OFF") - set(ONEDNN_BUILD_EXAMPLES "OFF") - set(ONEDNN_BUILD_TESTS "OFF") - set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") - set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) - - set(DNNL_CPU_RUNTIME "OMP") - - FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) + target_link_libraries(dnnl_ext dnnl) + target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -275,7 +261,6 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) @@ -289,14 +274,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ${VLLM_EXT_SRC}) add_compile_definitions(-DCPU_CAPABILITY_AVX512) endif() -elseif(POWER10_FOUND) - set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" - ${VLLM_EXT_SRC}) endif() -if (ASIMD_FOUND) + +if(USE_ONEDNN) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" + "csrc/cpu/dnnl_kernels.cpp" ${VLLM_EXT_SRC}) endif() diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index ee6768bce26ca48714361176e502a2d3f8829669..02224cfe3ee81289385f335791794186ac33ff44 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d06290dfd6e3df041f9b3d3ac8cb4220db3..6dd6f269f3dc955558914fe9387c0cbe30d22f32 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); diff --git a/csrc/cache.h b/csrc/cache.h index 02049d3582f0761bf3268441d54978473d8f0151..31ec03ca49f4b4d57752b6e37ab3dffb6e615040 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -42,10 +42,35 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, const std::string& kv_cache_dtype, torch::Tensor& scale); +void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& cp_local_token_select_indices, + torch::Tensor& kv_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); + // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); + +void gather_and_maybe_dequant_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); + +// TODO(hc): cp_gather_cache need support scaled kvcahe in the future. +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); + void read_cache( torch::Tensor& keys, torch::Tensor& values, @@ -61,10 +86,3 @@ void write_cache_multi_layers( std::vector const& value_caches, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); - -void gather_cache( - torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] - torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] - torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] - torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional seq_starts = std::nullopt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3132affeb4b5e2db1fdc36efa8cdc1795c87dc4d..89ac4761be44b177604cd1e8b2834fb8965d0e99 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -592,6 +593,51 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } +template +__global__ void cp_fused_concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim] + const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = cp_local_token_select_indices[blockIdx.x]; + const int64_t slot_idx = slot_mapping[blockIdx.x]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -896,6 +942,20 @@ void write_cache_multi_layers( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::cp_fused_concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + cp_local_token_select_indices.data_ptr(), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -934,6 +994,50 @@ void concat_and_cache_mla( CALL_CONCAT_AND_CACHE_MLA); } +// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel +// calls into one: +// k_c_normed.index_select(0, cp_local_token_select_indices) + \ +// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \ +// concat_and_cache_mla. +void cp_fused_concat_and_cache_mla( + torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_total_tokens, pe_dim] + torch::Tensor& cp_local_token_select_indices, // [num_tokens] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CP_FUSED_CONCAT_AND_CACHE_MLA); +} + namespace vllm { template @@ -1012,9 +1116,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -1022,6 +1126,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -1063,10 +1168,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < full_blocks_end; ++pid) { @@ -1093,8 +1204,144 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_and_maybe_dequant_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); +} + +namespace vllm { +template +// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by +// block_size. +__global__ void cp_gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRY_SIZE] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts // Optional: starting offsets per + // batch +) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on it + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + if (seq_starts != nullptr) { + offset += seq_starts[bid]; + } + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * dst_entry_stride; + copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr); + offset += 1; + // bump to next block + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \ + vllm::cp_gather_cache<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ reinterpret_cast(dst.data_ptr()), \ block_table.data_ptr(), cu_seq_lens.data_ptr(), \ @@ -1104,9 +1351,9 @@ __global__ void gather_cache( // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence -// - Optionally, seq_starts (if provided) offsets the starting block index by -// (seq_starts[bid] / page_size) -void gather_cache( +// - Optionally, seq_starts (if provided) offsets the starting slot index by +// seq_starts[bid] +void cp_gather_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] @@ -1157,11 +1404,11 @@ void gather_cache( seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); + CALL_CP_GATHER_CACHE(uint32_t); } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); + CALL_CP_GATHER_CACHE(uint16_t); } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); + CALL_CP_GATHER_CACHE(uint8_t); } else { TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); } diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 3952c43cbc727de4dcdb2de2fa447d837742d123..982f7c07a13bd33a1deab10786e58cae527d5c0d 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 1)) {} void save(void* ptr) const { - *reinterpret_cast<__m256i*>(ptr) = reg_low; - *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; + _mm256_storeu_si256((__m256i*)ptr, reg_low); + _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); } }; #endif diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3f00edb36068aee040927c49c4cf3087cead3aa --- /dev/null +++ b/csrc/cpu/dnnl_helper.cpp @@ -0,0 +1,346 @@ +#include +#include + +#include "common/memory_desc.hpp" +#include "common/memory.hpp" + +#include "dnnl_helper.h" + +static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; +} + +static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; +} + +void release_dnnl_matmul_handler(int64_t handler) { + DNNLMatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + delete ptr; +} + +template +class DNNLPrimitiveCache { + public: + using cache_value_t = std::pair; + using result_value_t = VT; + using container_t = std::list; + using value_iterator_t = typename container_t::iterator; + using map_t = std::unordered_map; + using creator_t = VT (*)(); + + public: + DNNLPrimitiveCache(size_t capacity) + : capacity_(capacity), + values_(), + key_to_value_(std::min(256lu, capacity)) { + assert(capacity > 0); + } + + template + result_value_t get_or_create(const KT& key, F&& creator) { + std::optional value = get_value(key); + if (value.has_value()) { + return value.value()->second; + } else { + return add_value({key, creator()})->second; + } + } + + size_t size() const { return values_.size(); } + + private: + void dump_data() { + std::stringstream ss; + ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec + << "\n"; + ss << "container: ["; + for (auto&& iter : values_) { + ss << "(" << iter.first << ", " << std::hex + << reinterpret_cast(iter.second.get()) << "), " << std::dec; + } + ss << "]\n"; + + ss << "map: ["; + for (auto&& iter : key_to_value_) { + ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex + << reinterpret_cast(iter.second->second.get()) << std::dec + << "), "; + } + ss << "]\n"; + std::printf("%s\n", ss.str().c_str()); + } + + value_iterator_t add_value(cache_value_t&& new_value) { + if (size() == capacity_) { + cache_value_t& last_item = values_.back(); + key_to_value_.erase(last_item.first); + values_.pop_back(); + } + + auto& added_value_ = values_.emplace_front(std::move(new_value)); + key_to_value_.emplace(added_value_.first, values_.begin()); + return values_.begin(); + } + + std::optional get_value(const KT& key) { + if (key_to_value_.size() > 0 && key == values_.begin()->first) { + return values_.begin(); + } + + auto value_map_iterator = key_to_value_.find(key); + if (value_map_iterator != key_to_value_.end()) { + values_.splice(values_.begin(), values_, value_map_iterator->second); + return value_map_iterator->second; + } else { + return {}; + } + } + + private: + const size_t capacity_; + container_t values_; + map_t key_to_value_; +}; + +DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( + const Args& args, dnnl::memory::data_type b_type) + : b_n_size_(args.b_n_size), + b_n_stride_(args.b_n_stride), + b_k_size_(args.b_k_size), + b_k_stride_(args.b_k_stride), + b_type_(b_type), + c_type_(args.c_type), + runtime_memory_ptrs_(8), + primitive_cache_size_(args.primitive_cache_size) { + assert(primitive_cache_size_ > 0); +} + +void DNNLMatMulPrimitiveHandler::prepack_weight( + void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); + dnnl::memory packed_weight(b_target_mem_desc, default_engine()); + { + dnnl::reorder(original_weight, packed_weight) + .execute(default_stream(), original_weight, packed_weight); + default_stream().wait(); + } + memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; + b_target_mem_desc_ = b_target_mem_desc; +} + +void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( + size_t index, dnnl_memory* memory_ptr) { + dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); + dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md()); + runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; +} + +std::pair +DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { + return runtime_memory_ptrs_[index]; +} + +namespace std { +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.a_qs)) ^ + hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^ + hash()(static_cast(val.c_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; +} // namespace std + +bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && + l.c_type == r.c_type; +} + +bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && + l.bias_type == r.bias_type; +} + +static std::shared_ptr +get_w8a8_class_primitive_cache( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), + dnnl::memory::data_type::s8), + use_azp_(args.use_a_zero_point), + a_qs_(args.a_quantization_strategy), + b_qs_(args.b_quantization_strategy), + m_size_cache_(nullptr) { + assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); + assert(b_qs_ != QuantizationStrategy::PER_TOKEN); + if (a_qs_ == QuantizationStrategy::PER_TOKEN) { + assert(!use_azp_); + }; + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); + a_scale_storage->set_data_handle((void*)args.a_scales_ptr); + } + if (use_azp_) { + auto&& [a_zero_point_storage, a_zero_point_mem_desc] = + get_runtime_memory_ptr(3); + a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); + } + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, + .b_k_size = b_k_size_, + .a_qs = a_qs_, + .b_qs = b_qs_, + .use_azp = use_azp_, + .c_type = c_type_}; + m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_); + } + + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + return dnnl::matmul(desc); + }); +} + +void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); + if (use_azp_) { + memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), + (void*)args.b_scales_ptr); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), (void*)args.b_scales_ptr); + } + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); +} + +dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab); + dnnl::memory::desc b_md; + if (first_time) { + b_md = + dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + } else { + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_SRC, 0); + if (use_azp_) { + attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + if (key.use_bias) { + // For PER_TOKEN, bias will be applied in epilogue + assert(a_qs_ == QuantizationStrategy::PER_TENSOR); + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..54ceefced9e985e4b3cd20a1df17e1cbc215d57e --- /dev/null +++ b/csrc/cpu/dnnl_helper.h @@ -0,0 +1,169 @@ +#ifndef DNNL_HELPER_H +#define DNNL_HELPER_H + +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +namespace dnnl { +namespace impl { +struct memory_storage_t; +struct matmul_pd_t; +struct matmul_desc_t; +} // namespace impl +} // namespace dnnl +struct dnnl_memory_desc; + +template +class DNNLPrimitiveCache; + +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} + +class DNNLMatMulPrimitiveHandler { + public: + virtual ~DNNLMatMulPrimitiveHandler() = default; + + protected: + struct Args { + dnnl_dim_t b_n_size; + dnnl_dim_t b_n_stride; + dnnl_dim_t b_k_size; + dnnl_dim_t b_k_stride; + void* b_ptr; + dnnl::memory::data_type c_type; + size_t primitive_cache_size; + }; + + protected: + DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); + + void prepack_weight(void* original_b_ptr, + dnnl::memory::desc b_target_mem_desc); + + void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); + + std::pair + get_runtime_memory_ptr(size_t index); + + protected: + const dnnl_dim_t b_n_size_; + const dnnl_dim_t b_n_stride_; + const dnnl_dim_t b_k_size_; + const dnnl_dim_t b_k_stride_; + dnnl::memory::data_type b_type_; + dnnl::memory::data_type c_type_; + std::unordered_map memory_cache_; + std::vector> + runtime_memory_ptrs_; + dnnl::memory::desc b_target_mem_desc_; + int64_t primitive_cache_size_; +}; + +class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; + + struct Args : public DNNLMatMulPrimitiveHandler::Args { + bool use_a_zero_point; + QuantizationStrategy a_quantization_strategy; + QuantizationStrategy b_quantization_strategy; + float* b_scales_ptr; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + QuantizationStrategy a_qs; + QuantizationStrategy b_qs; + bool use_azp; + dnnl::memory::data_type c_type; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const int8_t* a_ptr; + const float* a_scales_ptr; + const int32_t* a_zero_points_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + W8A8MatMulPrimitiveHandler(const Args& args); + + QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } + + bool get_input_use_zero_point() const { return use_azp_; } + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + const bool use_azp_; + const QuantizationStrategy a_qs_; + const QuantizationStrategy b_qs_; + std::shared_ptr m_size_cache_; +}; + +#endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp deleted file mode 100644 index 1cb8dc5b25a66e7d677d46206a28ec158c6abe2c..0000000000000000000000000000000000000000 --- a/csrc/cpu/dnnl_helper.hpp +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef DNNL_HELPER_HPP -#define DNNL_HELPER_HPP - -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" - -namespace { -template -struct DNNLType { - static constexpr dnnl::memory::data_type type = - dnnl::memory::data_type::undef; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; -}; - -template -constexpr inline dnnl::memory::data_type get_dnnl_type() { - return DNNLType>::type; -} -}; // namespace - -template -class DNNLPrimitiveHelper { - public: - // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) - // A: [M, K], row-major - // B: [K, N], column-major - // C: [M, N], row-major - // bias: [N], row-major, optional - // a_scales: [MS] - // b_scales: [NS] - // Note: Due to the limitation of oneDNN - // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is - // not supported. - - template - static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, - const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, - dnnl_dim_t K, const float* a_scales, - const float* b_scales, dnnl_dim_t MS, - dnnl_dim_t NS) { - auto&& OutputType = get_dnnl_type(); - auto&& BiasType = get_dnnl_type(); - - dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); - dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); - dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); - - dnnl::primitive_attr attr; - if constexpr (!InputNoScale) { - if (MS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_SRC, 0); - } else { - // per-token - TORCH_CHECK(false, "per-token quantization is unsupported."); - } - } - - if (NS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } else { - // per-channel - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); - } - - dnnl::matmul::primitive_desc matmul_pd; -// Create memory descriptors with format_tag::any for the primitive. This -// enables the matmul primitive to choose memory layouts for an -// optimized primitive implementation, and these layouts may differ from the -// ones provided by the user. -#ifdef __aarch64__ - auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, - dnnl::memory::format_tag::any); - auto mat_weights_md = dnnl::memory::desc( - {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); - auto mat_dst_md = - dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, - mat_weights_md, bias_md, - mat_dst_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc( - default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); - } -#else - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - bias_md, c_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - c_md, attr); - } -#endif - dnnl::matmul matmul(matmul_pd); - - auto& engine = default_engine(); - - dnnl::memory a_m(a_md, engine, (void*)a); - dnnl::memory b_m(b_md, engine, (void*)b); - dnnl::memory c_m(c_md, engine, (void*)c); - dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)a_scales); - dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)b_scales); - - auto& stream = default_stream(); - - auto mat_src_mem = a_m; - auto mat_weights_mem = b_m; - auto mat_dst_mem = c_m; -#ifdef __aarch64__ - if (matmul_pd.weights_desc() != b_m.get_desc()) { - mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); - dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); - } -#endif - if constexpr (InputNoScale) { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } else { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } - stream.wait(); - } - - private: - static dnnl::engine& default_engine() { - static dnnl::engine engine(dnnl::engine::kind::cpu, 0); - return engine; - } - - static dnnl::stream& default_stream() { - static dnnl::stream stream(default_engine()); - return stream; - } -}; -#endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..acc3b9ecde143382728a9a709081971de30ad3f5 --- /dev/null +++ b/csrc/cpu/dnnl_kernels.cpp @@ -0,0 +1,494 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.h" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; +#endif + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = azp_val; + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const int32_t* azp, + const float* azp_adj, const scalar_t* bias, + const int64_t num_tokens, + const int64_t hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + const int64_t thread_num = omp_get_max_threads(); + if (num_tokens > thread_num) { +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + const float* input_ptr = input + i * hidden_size; + scalar_t* output_ptr = output + i * hidden_size; + int64_t j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; ++j) { + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j); + } + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j, hidden_size - j); + } + } else { + const int64_t vec_iteration = + (hidden_size + vec_elem_num - 1) / vec_elem_num; + const int64_t vec_iteration_per_thread = + (vec_iteration + thread_num - 1) / thread_num; + const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; +#pragma omp parallel for schedule(static, 1) + for (int64_t i = 0; i < thread_num; ++i) { + const int64_t start = elem_num_per_thread * i; + const int64_t end = std::min(hidden_size, elem_num_per_thread + start); + for (int64_t j = 0; j < num_tokens; ++j) { + cvt_vec_t token_scale_vec(a_scale[j]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[j] * static_cast(azp[j]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + int64_t k = start; + const float* input_ptr = input + j * hidden_size; + scalar_t* output_ptr = output + j * hidden_size; + for (; k < end - vec_elem_num; k += vec_elem_num) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k); + } + if (k < end) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k, end - k); + } + } + } + } +} +} // namespace + +int64_t create_onednn_scaled_mm_handler( + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& b_scales, // [1] or [OC] + at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(b_scales.is_contiguous()); + + W8A8MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + if (b_scales.numel() == 1) { + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + } else { + TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; + } + args.b_scales_ptr = b_scales.data_ptr(); + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + if (dynamic_act_quant) { + // dynamic per-token, bias, A scales and A zps will be applied in outside. + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; + args.use_a_zero_point = false; + } else { + // static per-tensor + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + args.use_a_zero_point = use_azp; + } + + VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", + [&] { + if (dynamic_act_quant) { + args.c_type = get_dnnl_type(); + } else { + args.c_type = get_dnnl_type(); + } + }); + + return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args)); +} + +void onednn_scaled_mm( + torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& a_scales, // [M] or [1] + const std::optional& azp, // [M] or [1] + const std::optional& azp_adj, // [M] or [1] + const std::optional& bias, // [N] + int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_scaled_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + W8A8MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + const int32_t* azp_ptr = nullptr; + if (azp.has_value()) { + azp_ptr = azp->data_ptr(); + } + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + TORCH_CHECK_EQ(a_scales.numel(), 1); + } + + W8A8MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_ptr = a.data_ptr(); + exec_args.a_m_size = a.size(0); + exec_args.bias_ptr = nullptr; + exec_args.use_bias = false; + exec_args.a_scales_ptr = nullptr; + exec_args.a_zero_points_ptr = nullptr; + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + if (bias.has_value()) { + exec_args.bias_ptr = bias->data_ptr(); + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = true; + } + exec_args.a_scales_ptr = a_scales.data_ptr(); + exec_args.a_zero_points_ptr = azp_ptr; + exec_args.c_ptr = c.data_ptr(); + ptr->execute(exec_args); + } else if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + exec_args.c_ptr = tmp_fp32_out.data_ptr(); + ptr->execute(exec_args); + if (bias.has_value()) { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + (scalar_t*)nullptr, c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr, + c.size(0), c.size(1)); + } + } + } else { + TORCH_CHECK(false, "invalid act quant type."); + } + }); +} + +// static-per-tensor quantization. +void static_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + const torch::Tensor& scale, std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int64_t stride = input.stride(0); + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + static_scaled_int8_quant_impl(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), nullptr, + num_tokens, stride, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + torch::Tensor& scale, // [batch, 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + const int64_t stride = input.stride(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, stride, + hidden_size); + } + }); +} diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp deleted file mode 100644 index 6e120b8d20a7ee1e1f9ca8018281c1b6fdf68339..0000000000000000000000000000000000000000 --- a/csrc/cpu/quant.cpp +++ /dev/null @@ -1,951 +0,0 @@ -#include "cpu_types.hpp" -#include "dnnl_helper.hpp" - -namespace { -template -struct KernelVecType { - using load_vec_type = void; - using azp_adj_load_vec_type = void; - using cvt_vec_type = void; -}; - -template <> -struct KernelVecType { - using load_vec_type = vec_op::FP32Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) -template <> -struct KernelVecType { - using load_vec_type = vec_op::BF16Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; -#endif - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power architecture-specific vector type - using load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures - using load_vec_type = vec_op::FP16Vec16; -#endif - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if defined(__AVX512F__) || defined(__aarch64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#elif defined(__powerpc64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#else -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " - "support.") -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "dynamic_scaled_int8_quant_impl requires " - "AVX512/powerpc64/AArch64 support.") -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_with_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, - "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} -#endif -} // namespace - -void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Ideally we want to fuse the GEMM and the scale procedure with oneDNN - // JIT, the intermediate data is cached in registers or L1. But for now - // the oneDNN GEMM code generation only supports two quantization - // patterns: per-tensor or per-output-channel of weight. - // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * - // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN - // GEMM, then the per-token scale (and bias) is applied with the epilogue - // C=s_a * C_inter + bias. - torch::Tensor tmp_fp32_out = - torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - bias->data_ptr(), a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } else { - // Compute C=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - nullptr, a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - } - }); -} - -void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const torch::Tensor& azp_adj, // [OC] - const std::optional& azp, // [1] or [M] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_azp only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); - } - if (azp) { - TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); - } - TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); - - // azp & bias types - TORCH_CHECK(azp_adj.dtype() == torch::kInt32); - TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); - TORCH_CHECK(!bias || bias->dtype() == c.dtype(), - "currently bias dtype must match output dtype ", c.dtype()); - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } - } else { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C_inter=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), bias->data_ptr(), - a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), - b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); - } else { - // Compute C_inter=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - - // Compute C=C_inter - s_a * s_b * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } else { - // Per-Tensor - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } - } - }); -} - -// static-per-tensor quantization. -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale, - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value() || azp->numel() == 1); - - const int hidden_size = input.size(-1); - const int num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -// dynamic-per-token quantization. -void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, // [..., 1] - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_ppc64le only supports INT8 inputs."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - // We dont need this - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - }); -} - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b20a054648428e191a328cf1180b7cf586ea00a0..c9f426bdf618ab457c922ad216ba6dee42ef93e5 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -6,25 +6,20 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); -void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); - -void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const torch::Tensor& azp_adj, - const std::optional& azp, - const std::optional& bias); - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); -#endif +void release_dnnl_matmul_handler(int64_t handler); + +int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, + const torch::Tensor& b_scales, + at::ScalarType output_type, + bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size); + +void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& a_scales, + const std::optional& azp, + const std::optional& azp_adj, + const std::optional& bias, + int64_t handler); void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, @@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) +#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ + defined(__powerpc64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Helper function to release oneDNN handlers + ops.def("release_dnnl_matmul_handler(int handler) -> ()", + &release_dnnl_matmul_handler); + + // Create oneDNN W8A8 handler + ops.def( + "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " + "output_type, bool dynamic_act_quant, bool use_azp, int " + "primitive_cache_size) -> int", + &create_onednn_scaled_mm_handler); + + // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization + ops.def( + "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " + "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); + ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); // Compute int8 quantized tensor for given scaling factor. ops.def( @@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); -#elif defined(__powerpc64__) - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCPU, - &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index f7b75c48373f68e9025020eea507415fb9405e2e..2728aa81f0c9f51906acb8c7673e55ccc55e9e43 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,6 +19,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_HALF_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. @@ -45,6 +52,15 @@ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) +#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \ + AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..78f7b3cc1aa2502aa1d7a846b37b41e0023b275e --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cu @@ -0,0 +1,757 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +namespace cg = cooperative_groups; + +namespace vllm { +namespace moe { + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ void topk_with_k2(T* output, T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = -INFINITY; + T second_largest = -INFINITY; + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, + T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + int64_t const topk_group, int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = cuda::std::numeric_limits::min(); + T topk_group_value = cuda::std::numeric_limits::min(); + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = cuda::std::numeric_limits::min(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, -INFINITY); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = + (topk_group_value != cuda::std::numeric_limits::min()); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? scores_with_bias[offset + i] + : cuda::std::numeric_limits::min(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = + i < topk + ? scores[s_topk_idx[i]] + : cuda_cast(0.0f); // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + topk_indices[i] = i; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, + IdxT* topk_indices, T* scores_with_bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, bool enable_pdl = false, + cudaStream_t const stream = 0) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ + T * scores_with_bias, int64_t const num_tokens, \ + int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, bool const renormalize, \ + double const routed_scaling_factor, bool enable_pdl, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, int32_t); +INSTANTIATE_NOAUX_TC(half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor) { + auto data_type = scores_with_bias.scalar_type(); + auto input_size = scores_with_bias.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + + torch::Tensor group_scores = torch::empty( + {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + + switch (data_type) { + case torch::kFloat16: + // Handle Float16 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kFloat32: + // Handle Float32 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kBFloat16: + // Handle BFloat16 + vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), + num_tokens, num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + default: + // Handle other data types + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } + return {topk_values, topk_indices}; +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index dda189d51b417181bf2a89af3596a85621959fc9..a7a6e581f485fb3e90de82b93f635d873ccb379e 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -23,6 +23,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 2922352a3f7ccb224c8639095c4c04bc01fc652c..ca0c873f49d9f7ac49e7f4af4d203fcbfa9562bf 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -45,8 +45,6 @@ void moe_permute( auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); CubKeyValueSorter sorter{}; int64_t* valid_num_ptr = nullptr; @@ -85,12 +83,14 @@ void moe_permute( }); // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); + // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - // update align_expert_first_token_offset + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } @@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& expert_first_token_offset, torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); + TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } -void moe_unpermute(const torch::Tensor& input, - const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, + const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, + const std::optional& expert_first_token_offset, int64_t topk, + torch::Tensor& hidden_states) { TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); } @@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} +} \ No newline at end of file diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 99c52ef17d08b052f35fd8033ab87e5011c8cb72..cd80bfda7dfdeeda30be466ad2cead6f2dc61b07 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -573,7 +573,7 @@ void topk_softmax( stream); } else { - assert(topk_indices.scalar_type() == at::ScalarType::Int64); + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index b4e3efa6cd2612afda54d73ffdecd114c48efa1f..c765361c40d75b49c5c8a6bc7f7321d1fc574565 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -88,6 +88,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "output_tensor) -> ()"); m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + // Apply grouped topk routing to select experts. + m.def( + "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "topk_group, int topk, bool renormalize, float " + "routed_scaling_factor) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index c727042ecf68554242203afd1151160db9ffdbe9..68720c43de4de6bf8162b32908e2ed56ec63d46a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -256,6 +256,14 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); // void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, // torch::Tensor& scale); +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant(torch::Tensor& out, + torch::Tensor& output_block_scale, + torch::Tensor& input, + torch::Tensor& input_global_scale); +#endif + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -363,6 +371,11 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 0000000000000000000000000000000000000000..fdac47c425d618d7f31350bd21f3c019e0bafcf8 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,418 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid harcode the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847f22158d24b2f3321a4da4b21cf90..15bb2c300543cc3aa2a1d5db6852baa457d82fac 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -10,7 +10,7 @@ template __global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, ElementAB* b_base_as_int, ElementC* out_base_as_int, @@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ static_cast(b_ptrs.data_ptr()), \ static_cast(out_ptrs.data_ptr()), \ @@ -61,6 +61,8 @@ void run_get_group_gemm_starts( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 100f4850844443591c8fa202edfcf53319b39a43..49cafcc32adc61d04ec52e1fb398df4552f6eda8 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, } } +namespace { +inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& atomic_buffer, + int64_t num_experts, int64_t n, + int64_t k, cudaStream_t stream, + const bool swap_ab) { + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + + const int32_t* topk_ptr = static_cast(topk_ids.data_ptr()); + int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + + if (swap_ab) { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } +} +} // namespace + +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD); + + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( bool may_swap_ab = (!blockscale_offsets.has_value()) && (topk_ids.numel() <= SWAP_AB_THRESHOLD); - if (may_swap_ab) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); if (blockscale_offsets.has_value()) { // fp4 path diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 106bacb4883cb26333a45b39dfbd29b33e32fd21..84843ee6e094927a9d1a5ad7206b2a9479ffdb48 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( version_num, ". Required capability: 90 or 100"); } +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) + get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, + problem_sizes2, num_experts, n, k, + blockscale_offsets); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " + "kernel for CUDA device capability: ", + version_num, ". Required capability: 90 or 100"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..9bbeb0334fb9a81c21c3a461a5dde86ede759444 --- /dev/null +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" + +namespace vllm { + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = c10::Half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = c10::BFloat16; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +template +__inline__ __device__ PackedVec compute_silu(PackedVec& vec, + PackedVec& vec2) { + PackedVec result; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + if constexpr (std::is_same_v) { + half2 val(0.5f, 0.5f); + half2 t0 = __hmul2(vec.elts[i], val); + half2 t1 = __hfma2(h2tanh(t0), val, val); + half2 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } else { + __nv_bfloat162 val(0.5f, 0.5f); + __nv_bfloat162 t0 = __hmul2(vec.elts[i], val); + __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val); + __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } + } + return result; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, + PackedVec& vec2, + float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + PackedVec out_silu = compute_silu(vec, vec2); + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(out_silu.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(out_silu.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4( +#else +silu_and_cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec in_vec2 = reinterpret_cast(in)[inOffset2]; + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + ; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numCols, SFout); + + out_pos = silu_and_cvt_warp_fp16_to_fp4( + in_vec, in_vec2, SFScaleVal, sf_out); + } + } +#endif +} + +} // namespace vllm + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + TORCH_CHECK(input.dtype() == torch::kFloat16 || + input.dtype() == torch::kBFloat16); + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "act_and_mul_quant_kernel", [&] { + auto input_ptr = reinterpret_cast(input.data_ptr()); + VLLM_DISPATCH_BYTE_TYPES( + output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type", + [&] { + vllm::silu_and_cvt_fp16_to_fp4 + <<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + }); +} diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09f32b528f39d5c56a5c899f6d2e6504..0d14ba15937c65a26d5936fa9fc42eb53eeec736 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -349,9 +349,12 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): - return list( - set(sch for impl_config in impl_configs - for sch in impl_config.schedules)) + # Use dict over set for deterministic ordering + return list({ + sch: None + for impl_config in impl_configs + for sch in impl_config.schedules + }.keys()) def unsigned_type_with_bitwidth(num_bits): @@ -568,78 +571,79 @@ def generate(): itertools.repeat(default_heuristic)) ] - # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) - # TODO (LucasWilkinson): Further tuning required - qqq_tile_heuristic_config = { - #### M = 257+ - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), - # "M > 256": ((128, 256), (2, 1, 1)), - "M > 256": ((128, 128), (2, 1, 1)), - #### M = 129-256 - "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), - "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 128": ((128, 256), (2, 1, 1)), - "M > 128": ((128, 128), (2, 1, 1)), - #### M = 65-128 - "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), - "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), - "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), - "M > 64": ((128, 128), (2, 1, 1)), - #### M = 33-64 - "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), - # Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), - "M > 32": ((128, 64), (2, 1, 1)), - #### M = 17-32 - "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), - "M > 16": ((256, 32), (2, 1, 1)), - #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), - None: ((128, 16), (1, 1, 1)), - } - - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore - for cond, tile_config in qqq_tile_heuristic_config.items() - ] - - QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), - ] - - impl_configs += [ - ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) - ] + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } + + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] + + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] + + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + # ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/csrc/quantization/marlin/dense/LICENSE b/csrc/quantization/marlin/dense/LICENSE deleted file mode 100644 index 1d1e4cf9c8233fabb9ed0e47b4ecdaff13bb89b4..0000000000000000000000000000000000000000 --- a/csrc/quantization/marlin/dense/LICENSE +++ /dev/null @@ -1,209 +0,0 @@ -Contains code from https://github.com/IST-DASLab/marlin - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h deleted file mode 100644 index 68c83d5478cf892fa50a5a06a95ef0f572fb6172..0000000000000000000000000000000000000000 --- a/csrc/quantization/marlin/dense/common/base.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h deleted file mode 100644 index 64f9c393d77ced6b5c61d336c77ee8dad9971f56..0000000000000000000000000000000000000000 --- a/csrc/quantization/marlin/dense/common/mem.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu deleted file mode 100644 index ea96326ed7e61e95d27d12b26e02eaece0c073a8..0000000000000000000000000000000000000000 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm", &marlin_gemm); -} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu deleted file mode 100644 index c96d68d9b29aaa38b27bfe82e9bd18d118595505..0000000000000000000000000000000000000000 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Adapted from - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp - * Modified by HandH1998 - * Copyright (C) 2024 HandH1998 - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "../dense/common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "../dense/common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS_GROUP = Vec; // weight per-group quantization scales -using FragS_CHANNEL = - Vec; // weight per-channel quantization scales or activaton - // per-token quantization scales - -// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, -// cp.async.ca can support BYTES = 4, 8, 16; -// as s_tok's shape is equal to prob_m, we need set s_tok to float type, -// and cp_size = 1 float, i.e., 4 BYTES -// Asynchronous global->shared copy for activation quantizaton scales s_tok -__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 4; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// m16n8k16 tensor core mma instruction with int8 inputs and int32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - int* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in int8 tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -inline __device__ half2 float2_to_half2(float2 f) { - uint32_t res; - // NOTE(HandH1998): h0,h1 should be uint16_t, not half - uint16_t h0, h1; - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); - return reinterpret_cast(res); -} - -inline __device__ float int32_to_float(int h) { - float res; - asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); - return res; -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per channel dequant. -__device__ inline FragB dequant_per_channel(int q) { - static constexpr int MASK = 0xf0f0f0f0; - FragB frag_b; - frag_b[0] = (q & MASK); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per group dequant. -__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { - static constexpr uint32_t LO = 0x000f000f; - static constexpr uint32_t HI = 0x00f000f0; - static constexpr uint32_t EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - static constexpr uint32_t SUB = 0x64086408; - static constexpr uint32_t MUL = 0x2c002c00; - static constexpr uint32_t ADD = 0xd480d480; - *reinterpret_cast(&t0) = __hsub2( - *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - - uint16_t s = reinterpret_cast(&frag_s)[i]; - uint32_t double_s; - // pack 2xfp16 to half2 - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); - // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 - // half, respectively) - static constexpr uint32_t MAGIC_NUM = 0x64806480; - *reinterpret_cast(&t0) = __hfma2( - *reinterpret_cast(&t0), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 - // int8 into 1 uint32 - FragB frag_b; - uint32_t uint8s; - static constexpr uint32_t MASK_0246 = 0x6420; - static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(uint8s) - : "r"(t0), "r"(t1), "n"(MASK_0246)); - frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); - return frag_b; -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if constexpr (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; - D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 16; - C += 16 * thread_m_blocks * prob_n / 4; - D += 16 * thread_m_blocks * prob_n / 8; - s_tok += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 16; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 1 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - constexpr int s_tok_sh_stride = 16 * thread_m_blocks; - - constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; - - int s_group_gl_stride = prob_n / 8; - constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_group_sh_stage = s_group_sh_stride; - int s_group_gl_rd_delta = s_group_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); - a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - auto s_tok_gl_rd = threadIdx.x; - // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, - // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for - // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as - // s_tok's size is not fixed, we can not shuffle before inference we shuffle - // it when fetching s_tok from global memory to shared memory, that's why - // s_tok_sh_wr is like this - int s_tok_sh_wr = - (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; - int s_tok_sh_rd = (threadIdx.x % 32) / 4; - bool s_tok_sh_wr_pred = threadIdx.x < prob_m; - - auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - auto s_ch_sh_wr = threadIdx.x; - int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - 2 * ((threadIdx.x % 32) % 4); - bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; - - int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; - bool s_group_sh_wr_pred; - if constexpr (group_blocks != -1) { - s_group_gl_rd = - s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_group_sh_stride * slice_col + threadIdx.x; - s_group_sh_wr = threadIdx.x; - // NOTE(HandH1998): s_group_sh_rd is related to mma output C - s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * - // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s_tok = sh_b + (stages * b_sh_stage); - int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; - int4* sh_s_group = sh_s_ch + s_ch_sh_stride; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS_GROUP frag_s_group[2][4]; - FragS_CHANNEL frag_s_tok[thread_m_blocks]; - FragS_CHANNEL frag_s_ch[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; - if (s_group_sh_wr_pred) - cp_async4(&sh_s_group_stage[s_group_sh_wr], - &s_group[s_group_gl_rd]); - s_group_gl_rd += s_group_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - int4* sh_s_group_stage = - sh_s_group + - s_group_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s_group[k % 2])[0] = - sh_s_group_stage[s_group_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - // int b_quant_shift = b_quant << 4; - FragB frag_b0, frag_b1; - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); - frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); - } else { - int b_quant_shift = b_quant << 4; - frag_b0 = dequant_per_channel(b_quant); - frag_b1 = dequant_per_channel(b_quant_shift); - } - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - int* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - // global_reduce works on INT32 elements, which are the results of INT8 GEMM. - // This is why we need another INT32 maxtrix `C` to reduce instead of the - // original half matrix `D`. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 4; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 8 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; - c_gl_wr += (4 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads * 2; - auto c_sh_wr = 2 * threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i + 1], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2) + 1], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; - int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - reinterpret_cast(&d_red1)[j]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += - reinterpret_cast(&d_red2)[j]; - } - } - if (!last) { - int4 d1, d2; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d1)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d2)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - d1; - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + - 1] = d2; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int d_gl_stride = prob_n / 8; - constexpr int d_sh_stride = 2 * thread_n_blocks + 1; - int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int d_sh_rd_delta = - d_sh_stride * (threads / (2 * thread_n_blocks)); - - int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - d_gl_wr += (2 * thread_n_blocks) * slice_col; - int d_sh_wr = - (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - d_sh_wr += 32 * (threadIdx.x / 32); - int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int d_gl_wr_end = d_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { - float2 deq_res; - deq_res.x = int32_to_float(c0) * w_s[0] * a_s; - deq_res.y = int32_to_float(c1) * w_s[1] * a_s; - ((half2*)sh)[idx] = float2_to_half2(deq_res); - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = d_sh_wr + 8 * j; - write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - } - d_sh_wr += 16 * (4 * d_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (d_gl_wr < d_gl_wr_end) { - D[d_gl_wr] = sh[d_sh_rd]; - d_gl_wr += d_gl_wr_delta; - d_sh_rd += d_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (last) { - if (s_tok_sh_wr_pred) { - cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); - } - if (s_ch_sh_wr_pred) { - cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); - } - cp_async_fence(); - } - thread_block_reduce(); - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - frag_s_tok[i][0] = - *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); - frag_s_tok[i][1] = *reinterpret_cast( - &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); - } - reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; - reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; - reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; - reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; - s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, - void* s_tok, void* s_ch, void* s_group, int prob_m, - int prob_n, int prob_k, void* workspace, - int groupsize = -1, int dev = 0, cudaStream_t stream = 0, - int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - int4* D_ptr = (int4*)D; - const float* s_tok_ptr = (const float*)s_tok; - const int4* s_ch_ptr = (const int4*)s_ch; - const int4* s_group_ptr = (const int4*)s_group; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. - if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; - D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - s_tok_ptr += 16 * thread_m_blocks * par; - } -} -} // anonymous namespace - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(size_m == s_tok.numel(), - "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(tile_size)); - TORCH_CHECK( - (size_k / tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + - ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); - - int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify N - TORCH_CHECK(s_ch.numel() == size_n, - "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(tile_size)); - if (groupsize != -1) { - TORCH_CHECK(s_group.size(1) == size_n, - "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - size_k % s_group.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); - } - - int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "Shape mismatch: size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify s_tok device, strides and dtype - TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); - TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); - TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); - - // Verify s_ch device, strides and dtype - TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); - TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); - TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); - - // Verify s_group device, strides and dtype - TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU"); - TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); - TORCH_CHECK(s_group.dtype() == torch::kFloat16, - "s_group's dtype is not float16"); - - // Verify workspace size - TORCH_CHECK(size_n % min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(min_thread_n)); - int min_workspace_size = (size_n / min_thread_n) * max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); - - // Alloc D matrix - auto options_d = - torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); - torch::Tensor d = torch::empty({size_m, size_n}, options_d); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - int dev = a.get_device(); - marlin_qqq_cuda( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), - s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); - - return d; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); -} diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh index 8aa0147df6ba8977d642684e24a9ee69d32a7bf0..98b491b7e23fc0e15820390d07254f3ec160fe0d 100644 --- a/csrc/quantization/vectorization_utils.cuh +++ b/csrc/quantization/vectorization_utils.cuh @@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment( for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } return; } @@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment( // 2. vectorize the main part for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } // 3. handle the tail @@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len, auto* v_in = reinterpret_cast(in); for (int i = tid; i < num_vec; i += stride) { - vec_op(v_in[i]); + vin_t tmp = v_in[i]; + vec_op(tmp); } return; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c1025eac9d296ed884d14b8d88834239be06e1c5..1d234c661f89fa4e3652c9f09dca82d80f0f38e1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -267,6 +267,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); // ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + ops.def( + "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); +#endif + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -422,14 +430,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " - "Tensor", - {stride_tag}); - // conditionally compiled so impl in source file - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " @@ -498,6 +498,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -534,15 +554,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // marlin_qqq_gemm for QQQ. - ops.def( - "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " - "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," @@ -621,6 +632,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes problem sizes for each expert's multiplication + // used by the two mms called from fused MoE operation. It takes topk_ids as + // an input, and computes problem_sizes1 and problem_sizes2 only. + ops.def( + "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int num_experts, int n, int k, " + " Tensor? blockscale_offsets) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each @@ -877,17 +901,37 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale) -> ()"); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); + cache_ops.def( + "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe," + " Tensor cp_local_token_select_indices," + " Tensor! kv_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor scale) -> ()"); + cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA, + &cp_fused_concat_and_cache_mla); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. + cache_ops.def( + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); + cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docker/Dockerfile b/docker/Dockerfile index 74938917781acac8a735215e903ff126cb158783..2e272cbca84177eed611852dd5d5e5f9d02f2b2b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -372,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt -# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. -ARG FLASHINFER_GIT_REF="v0.2.11" +# Keep this in sync with "flashinfer" extra in setup.py +ARG FLASHINFER_GIT_REF="v0.2.14.post1" +# Flag to control whether to compile FlashInfer AOT kernels +# Set to "true" to enable AOT compilation: +# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... +ARG FLASHINFER_AOT_COMPILE=false RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . /etc/environment git clone --depth 1 --recursive --shallow-submodules \ --branch ${FLASHINFER_GIT_REF} \ ${FLASHINFER_GIT_REPO} flashinfer - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - # Needed to build AOT kernels pushd flashinfer - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation --force-reinstall --no-deps . + if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. + if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi + echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + # Build AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer.aot + # Install with no-build-isolation since we already built AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + uv pip install --system --no-build-isolation . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + # Download pre-compiled cubins + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." + else + echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" + uv pip install --system . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + fi popd rm -rf flashinfer BASH @@ -418,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # Install DeepGEMM from source -ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" -RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . /etc/environment - CUDA_MAJOR="${CUDA_VERSION%%.*}" - CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" - CUDA_MINOR="${CUDA_MINOR%%.*}" - if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then - git clone --recursive --shallow-submodules \ - ${DEEPGEMM_GIT_REPO} deepgemm - echo "🏗️ Building DeepGEMM" - pushd deepgemm - git checkout ${DEEPGEMM_GIT_REF} - # Build DeepGEMM - # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) - rm -rf build dist - rm -rf *.egg-info - python3 setup.py bdist_wheel - uv pip install --system dist/*.whl - popd - rm -rf deepgemm - else - echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" - fi -BASH +COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ + && rm /tmp/install_deepgemm.sh + +# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh +COPY tools/install_nixl.sh install_nixl.sh +ENV CUDA_HOME=/usr/local/cuda +RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ + && bash install_python_libraries.sh \ + && bash install_nixl.sh --force #################### vLLM installation IMAGE #################### diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4f40f32a39f26a5352b8806ba59382cc085fcbdc..f1648573250434c230b9b2a7219ef116c47e18ea 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 4e89bb3057c5ee6cd9b875d85cf6bb2d6f9e5f75..9270b48c54d4b173c357dc07a853464a44eacbf6 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ microdnf clean all # Python Installation @@ -136,6 +136,71 @@ RUN --mount=type=cache,target=/root/.cache/uv \ mkdir -p /tmp/hf-xet/dist && \ cp dist/*.whl /tmp/hf-xet/dist/ +# Build numba +FROM python-install AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +WORKDIR /tmp + +# Clone all required dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + microdnf install ninja-build gcc gcc-c++ -y && \ + git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7 && \ + git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd llvm-project && mkdir build && cd build && \ + uv pip install 'cmake<4' setuptools numpy && \ + export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \ + CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \ + CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \ + cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_LIBRARY_PATH="${PREFIX}" \ + -DLLVM_ENABLE_LIBEDIT=OFF \ + -DLLVM_ENABLE_LIBXML2=OFF \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_INCLUDE_BENCHMARKS=OFF \ + -DLLVM_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF \ + -DLLVM_INCLUDE_GO_TESTS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_UTILS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_UTILS_INSTALL_DIR=libexec/llvm \ + -DLLVM_BUILD_LLVM_DYLIB=OFF \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \ + -DLLVM_ENABLE_FFI=ON \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_OPTIMIZED_TABLEGEN=ON \ + -DCMAKE_POLICY_DEFAULT_CMP0111=NEW \ + -DCOMPILER_RT_BUILD_BUILTINS=ON \ + -DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF \ + -DCOMPILER_RT_BUILD_LIBFUZZER=OFF \ + -DCOMPILER_RT_BUILD_CRT=OFF \ + -DCOMPILER_RT_BUILD_MEMPROF=OFF \ + -DCOMPILER_RT_BUILD_PROFILE=OFF \ + -DCOMPILER_RT_BUILD_SANITIZERS=OFF \ + -DCOMPILER_RT_BUILD_XRAY=OFF \ + -DCOMPILER_RT_BUILD_GWP_ASAN=OFF \ + -DCOMPILER_RT_BUILD_ORC=OFF \ + -DCOMPILER_RT_INCLUDE_TESTS=OFF \ + ${CMAKE_ARGS} -GNinja ../llvm \ + + && ninja install . && \ + # build llvmlite + cd ../../llvmlite && python setup.py bdist_wheel && \ + cd ../numba && \ + if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python setup.py bdist_wheel + + # Final build stage FROM python-install AS vllm-cpu ARG PYTHON_VERSION @@ -163,23 +228,30 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ - ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ - VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ - HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ - TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ + HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ + LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ + NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ $TORCH_WHL_FILE \ + $LLVM_WHL_FILE \ + $NUMBA_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ - -r requirements/cpu.txt + -r requirements/cpu.txt + # Build and install vllm RUN --mount=type=cache,target=/root/.cache/uv \ - VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \ uv pip install "$(echo dist/*.whl)[tensorizer]" # setup non-root user for vllm @@ -196,4 +268,3 @@ WORKDIR /home/vllm # Set the default entrypoint ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] - diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 21901513697611dfc79db432417e95d154316c51..ca2d7833c1efa9664abee1ecaebf1a91af68721c 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -7,7 +7,8 @@ WORKDIR /workspace/vllm # Install some basic utilities RUN apt-get update && apt-get install -y \ git \ - ffmpeg libsm6 libxext6 libgl1 + ffmpeg libsm6 libxext6 libgl1 && \ + rm -rf /var/lib/apt/lists/* # Build vLLM. COPY . . @@ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # Remove existing versions of dependencies +# TODO: These packages will remain as dead weight in the Docker image layers. +# We should find a way to build the image without uninstalling these. +# Consider using a different base image. RUN pip uninstall -y torch torch_xla torchvision ENV VLLM_TARGET_DEVICE="tpu" @@ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ -r requirements/tpu.txt -RUN python3 -m pip install -e . + +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e . # install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils CMD ["/bin/bash"] diff --git a/docs/api/README.md b/docs/api/README.md index 327472df1d52c32404927ad03d3e2c2a315c4f99..57142e8f5625d6973ea3fcb3f8f157d18b571564 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -77,6 +77,7 @@ Internal data structures. - [vllm.multimodal.inputs.MultiModalFieldElem][] - [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalKwargsItem][] +- [vllm.multimodal.inputs.MultiModalKwargsItems][] - [vllm.multimodal.inputs.MultiModalKwargs][] - [vllm.multimodal.inputs.MultiModalInputs][] diff --git a/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png new file mode 100644 index 0000000000000000000000000000000000000000..185f61e6a3edeefb7f3c2bc5bcb942af3408da0e Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/full_attn.png b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png new file mode 100644 index 0000000000000000000000000000000000000000..30eade5c7051cc050ffd0ccae797fd083d4389b0 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png new file mode 100644 index 0000000000000000000000000000000000000000..bcffc27a716497b2b280be176b4f40e520dadda0 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/overview.png b/docs/assets/design/hybrid_kv_cache_manager/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..ac80581f491da6235c53f14601b74288ead3b185 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/overview.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png new file mode 100644 index 0000000000000000000000000000000000000000..10aa6146dc7ab478dc60c9383ab9fdb1c5de9d30 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png differ diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 36232e6ad96cc87196129250c1c523ba0e4ec187..221a7bd96213f76fc7066c0410d55c0981840579 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,8 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0b1e8b9c1a1878b36301ff8579c9ee1..efda9c8e019ebfc1441a3b30e2581aa21695670b 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 2eeb8ad25de5f86c93aa89a8952125452fec2fb1..2d8cdcc11fa9970081eb5d08cd4c3a94e84781c0 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -48,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`: - Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. - Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. -- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. +- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). ```python @@ -129,6 +129,56 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. +### Batch-level DP for Multi-Modal Encoders + +By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, +in order to reduce the memory and compute load on each GPU. + +However, since the size of multi-modal encoders is very small compared to language decoders, +there is relatively little gain from TP. On the other hand, TP incurs significant communication +overhead because of all-reduce being performed after every layer. + +Given this, it may be advantageous to instead shard the batched input data using TP, essentially +performing batch-level DP. This has been shown to improve the throughput by around 10% for +`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, +batch-level DP can provide another 40% increase to throughput compared to regular TP. + +Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, +there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. + +You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: + +```python +from vllm import LLM + +llm = LLM( + model="Qwen/Qwen2.5-VL-72B-Instruct", + tensor_parallel_size=4, + # When mm_encoder_tp_mode="data", + # the vision encoder uses TP=4 (not DP=1) to shard the input data, + # so the TP size becomes the effective DP size. + # Note that this is independent of the DP size for language decoder which is used in expert parallel setting. + mm_encoder_tp_mode="data", + # The language decoder uses TP=4 to shard the weights regardless + # of the setting of mm_encoder_tp_mode +) +``` + +!!! important + Batch-level DP is not to be confused with API request-level DP + (which is instead controlled by `data_parallel_size`). + +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: + +- Llama4 () +- MiniCPM-V-2.5 or above (, ) +- Qwen2.5-VL () +- Step3 () + ## Input Processing ### Parallel Processing @@ -149,21 +199,41 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 !!! note API server scale-out is only available for online inference. +!!! warning + By default, 8 CPU threads are used in each API server to load media items (e.g. images) + from request data. + + If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` + to avoid CPU resource exhaustion. + !!! note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled + API server scale-out disables [multi-modal IPC caching](#ipc-caching) because it requires a one-to-one correspondance between API and engine core processes. -## Multi-Modal Caching + This does not impact [multi-modal processor caching](#processor-caching). -### Processor Cache +## Multi-Modal Caching -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. Examples: @@ -176,3 +246,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: + +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index a93435ed71b50054bf2f371012c05b9d5b883392..e456077e04958a3b177f8d9f602e4ff060dbe6db 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,32 +45,32 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. -The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. #### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. +For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. -The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. -However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compilaed graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. +However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. #### Quantization diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 7ef22d6f8c3f5e505d0ee92b87c4b9c5e9a4d3eb..3dae62dd5d944c1450bb227768184af990a78eb3 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -90,7 +90,7 @@ address the long build time at its source, the current workaround is to set `VLL to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: -1. Increase the timeout limit to 10 hours so that the build doesn't timeout. +1. Increase the timeout limit to 10 hours so that the build doesn't time out. 2. Allow the compiled artifacts to be written to the vLLM sccache S3 bucket to warm it up so that future builds are faster. diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index 21b1f21d60a359490253d64c73caa2a2152bdeae..aafdb1058e03c087add4811a1ff1e672ef4056f5 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -121,3 +121,31 @@ To support a model with interleaving sliding windows, we need to take care of th - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171). With these two steps, interleave sliding windows should work with the model. + +### How to support models that use Mamba? + +We consider 3 different scenarios: + +1. Models that use Mamba layers (either Mamba-1 or Mamba-2) but do not use attention layers. +2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers. +3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers. + +For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. +The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config. +For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. +Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations. +V0-only classes and code will be removed in the very near future. +The model should also be added to the `MODELS_CONFIG_MAP` dictionary in to ensure that the runtime defaults are optimized. + +For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). +These models should follow the same instructions as case (1), but they should inherit protocol `IsHybrid` (instead of `IsAttentionFree`) and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol). + +For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. +Please follow the same guidelines as case (2) for implementing these models. +We use "mamba-like" to refer to layers that posses a state that is updated in-place, rather than being appended-to (like KV cache for attention). +For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`. +It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers. +Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:v1/attention/backends/short_conv_attn.py) for examples of this. +Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it. +Please see the calls to `direct_register_custom_op` in or for examples of this. +The new custom op should then be added to the list `_attention_ops` in to ensure that piecewise CUDA graphs works as intended. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 64a48be32645a70ad09d2a8a65b0e7ad7af0fdc8..dc742c8fcf2cd59d5dceffa80ec2445577b6b3c4 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id @@ -855,7 +855,7 @@ Examples: ### Custom HF processor -Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. +Some models don't define an HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. Examples: diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b2085cafb8be8c675fb05b9c5213ffa8ff..0b41e73b030cc1f9eb288a789e6c158a7319784c 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). -- On the bottom left of open settings, AI Prooviders --> LLM: +- On the bottom left of open settings, AI Providers --> LLM: - LLM Provider: Generic OpenAI - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index 23dc58c974ed809083f7f7fad5b7ab9dc43fbc2e..fe4d87f78f2aaf1b8c3e8f2768e83bb3fabb3dbc 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), To install dstack client, run: ```bash -pip install "dstack[all] +pip install dstack[all] dstack server ``` diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md index e3e7dbe6e1e8058ee8b3d357064c56d94d439d19..8ecd1484eab066341b297e94acaf3884574fc83d 100644 --- a/docs/deployment/frameworks/lobe-chat.md +++ b/docs/deployment/frameworks/lobe-chat.md @@ -6,6 +6,6 @@ Supports speech-synthesis, multi-modal, and extensible (function call) plugin sy One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. -It supports vLLM as a AI model provider to efficiently serve large language models. +It supports vLLM as an AI model provider to efficiently serve large language models. For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index cad801a4312cce400726911f14a30093d675a1ce..ca23e0b9fd8af1fa209540aef059985e9ef9d5f2 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -380,7 +380,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ### Startup Probe or Readiness Probe Failure, container log contains "KeyboardInterrupt: terminated" -If the startup or readiness probe failureThreshold is too low for the time needed to startup the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: +If the startup or readiness probe failureThreshold is too low for the time needed to start up the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: 1. container log contains "KeyboardInterrupt: terminated" 2. `kubectl get events` shows message `Container $NAME failed startup probe, will be restarted` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 4b917ab408eecf843f468560d41c6c0ea0db7f99..b03483d1c9b217e1fc67e353fc93849f2cdea8e8 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -133,12 +133,12 @@ class FusedMoEModularKernel: Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, * PplxPrepareAndFinalize type is backed by Pplx All2All kernels, -* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughtput All2All kernels, and +* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to setup the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). #### Step 2: Add a FusedMoEPrepareAndFinalize Type @@ -183,7 +183,7 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking #### maybe_make_prepare_finalize -The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. +The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. Please refer to the implementations in, * `ModelOptNvFp4FusedMoE` @@ -198,7 +198,7 @@ Please refer to the implementations in, * `CompressedTensorsW8A8Fp8MoECutlassMethod` * `Fp8MoEMethod` * `ModelOptNvFp4FusedMoE` -dervied classes. +derived classes. #### init_prepare_finalize @@ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. ### How To Profile diff --git a/docs/design/hybrid_kv_cache_manager.md b/docs/design/hybrid_kv_cache_manager.md new file mode 100644 index 0000000000000000000000000000000000000000..8f17b473adc085729f781410cf5aed4c9a364535 --- /dev/null +++ b/docs/design/hybrid_kv_cache_manager.md @@ -0,0 +1,245 @@ +# Hybrid KV Cache Manager + +!!! warning + This document was written based on commit [458e74](https://github.com/vllm-project/vllm/commit/458e74eb907f96069e6d8a4f3c9f457001fef2ea). This feature is still in its early stage and things may change. + +## What is a hybrid model? + +Many recent "hybrid" LLMs combine multiple attention types within one model. For example: + +1. Sliding window attention (sw) + full attention (full): gpt-oss, Gemma 2/3, Ministral, cohere, etc. +2. Mamba + full: Bamba, Jamba, Minimax, etc. +3. Local chunked attention + full: Llama4 + +To serve these models efficiently, our [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] must: + +1. Allocate different slots to different layer type, for example: + - Full attention layers: reserve slots for **all** tokens. + - Sliding window layers: reserve slots only for the most recent **`sliding_window_size`** tokens. +2. Support layer-specific prefix-cache rules, for example: + - Full attention: a cache hit prefix requires **all** tokens remain in the KV cache. + - Sliding window: a cache hit prefix only requires the last **`sliding_window_size`** tokens remain in the KV cache. + +## Definitions + +1. **kv hidden size**: The number of bytes to store one token's KV cache for a single layer. +2. **block**: the memory reserved for kv cache are divided into multiple *blocks* with the same *page size* (defined below) +3. **block size**: number of tokens inside a block +4. **page size**: the physical memory size of a block, defined as: + + $$ + \text{num_layers} \times \text{block_size} \times \text{kv_hidden_size} + $$ + + `num_layers` doesn't mean the total number of layers in the model. The exact number depends on the context in this doc. + + !!! note + This is different from `KVCacheSpec.page_size_bytes` in the code, which is defined as: + + $$ + \text{block_size} \times \text{kv_hidden_size} + $$ + +## Allocation + +### High level idea + +We use a single memory pool for all layer types. The memory pool is split into multiple blocks with the same page size. [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates different numbers of blocks to different layers according to its attention type. + +The core challenge is ensuring every layer type uses the same **page size**. For full-attention-only models, the page size is straightforward, defined as: + +$$ +\text{page_size} = \text{block_size} \times \text{num_hidden_layers} \times \text{kv_hidden_size} +$$ + +However, in hybrid models, `num_hidden_layers` varies by attention type, which would normally produce mismatched page sizes. The cases below show how we unify them. + +### Case 1: toy model + +Let's start with a toy example: a model has 1 full attention layer and 3 sliding window attention layers. All layers have the same `kv_hidden_size`. + +We let each block to hold `block_size` tokens for one layer, so: + +$$ +\text{page_size} = \text{kv_hidden_size} \times \text{block_size} +$$ + +[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates a different number of blocks to each layer. + +This case is only a toy example. For real models, please refer to the following cases. + +### Case 2: same `kv_hidden_size` and a regular pattern + +When the model has more layers, e.g., 20 sliding window attention layers and 10 full attention layers with the same `kv_hidden_size`. Calling the allocator once per layer (30 calls) is OK but becomes inefficient. As a solution, we group the allocation of layers that need the same number of blocks to reduce the number of calls. + +The grouping is feasible because there is usually a beautiful ratio between the number of different types of layers. For example: + +- Gemma-2: 1 sw : 1 full +- Llama 4: 3 local : 1 full + +Our example can be regarded as 2 sw : 1 full. We can allocate blocks as if there are 2 sw and 1 full in the model, and repeat the result by 10 times to generate the `block_ids` for the 30 layers. The page size becomes: + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +Assume `block_size` 16, sliding window size 32, request length 112, then for the above example model, we need to allocate 11 blocks (0-6 for full, 7-8 for sw group 1, 9-10 for sw group 2). + +![Allocation Result](../assets/design/hybrid_kv_cache_manager/basic_grouping_example.png) + +Here, "/" denotes no block needed (sliding‑window layers don't need slots for early tokens). + +See the formal definition below. The layers are divided into multiple *KV Cache Groups* so that there is: + +1. **Identical attention type inside each group**: Each group only contains layers with the same attention type and thus need the same number of blocks for a given request. This enables layers in the same group share the same block ids without memory waste. +2. **Identical page size across groups**: Because our memory pool only have one page size. + +Our example model is divided into 3 KV cache groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +Obviously, it satisfies rule 1. For rule 2, all 3 groups have + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +as their page size. + +### Case 3: same `kv_hidden_size` and no regular pattern + +Unfortunately, not all models have such a beautiful ratio, and approach in Case 2 will produce too many small groups. For example, Gemma-3-27b has 52 sliding window attention layers and 10 full attention layers. With the constraints in case 2, it would be 26 sliding window groups and 5 full attention groups, each contains 2 layers. The allocation is still inefficient. To reduce the number of kv cache groups, we group layers using the smallest layer count among all attention types. For example, min(52, 10)=10 layers per group in Gemma-3-27b. Then the grouping result is: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) +- ... +- Group 6: 10 sliding window attention layers (sw.40 - sw.49) +- Group 7: 2 sliding window attention layers (sw.50 - sw.51) and 8 padding layers + +We will update this algorithm if this heuristic leads to a bad result when a new model comes out (e.g., 20 full + 30 sw, the group size should be 10 instead of 20). + +This case happens in Gemma-3 series models, and models in case 2 but with eagle speculative decoding which introduce one full attention layer. The solution has some memory waste and is not perfect. Please report any cases where padding overhead becomes unacceptable so we can refine the algorithm. + +### Case 4: different `kv_hidden_size` (mainly hybrid mamba models) + +Some architectures (e.g., Bamba, Jamba, Minimax) interleave standard attention layers with Mamba layers, where each Mamba layer's state size per token can be much larger than the attention layers' `kv_hidden_size`. Because we only support a single page size across all groups, we must reconcile these differing hidden sizes. + +The current algorithm is: + +1. Increase the `block_size` of attention layers until + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \ge \text{state_size}_{\text{mamba}} + $$ +2. Pad the mamba state per layer to + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} + $$ +3. Apply the grouping strategy in case 3. + +!!! note + This can lead to more than 400 `block_size` for attention layers, which is too large. Another padding strategy is to increase `block_size` until + + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \times \text{num_attn_layers} \ge \text{state_size}_{\text{mamba}} + $$ + + This padding strategy is still a work in progress. + +### Case 5: KV sharing + +KV sharing refers to a layer using the KV cache of another layer, e.g., gemma-3n. +In these models, [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] ignores all layers with kv sharing and only allocates KV cache for layers that need kv cache, and some patches are made in model runner to apply the allocation result to kv sharing layers. + +## Prefix caching + +For simplicity, we assume `block_size=1` in this section. + +### High level idea + +The block pool uses a dict similar to `tuple(block_hash, group_id) -> block` to catch the full blocks. That means the same tokens of different groups are cached and evicted independently. + +When a new request comes in, we check the cache hit prefix of each group, and return the intersection of these groups as the cached prefix of the request. See below for the detailed algorithm for checking the cache hit of one group & performing the intersection. + +### Case 0: full attention only models + +For full attention layers, blocks are allocated for all tokens in the request. For details on the underlying design, see [Prefix Caching](prefix_caching.md) + +To find the longest cache hit prefix of a request, we enumerate from left (the first block) to right (the last block), checking whether the block is cached, and exit when cache misses. For example, we will return the first 7 tokens (0-6) as the cache hit prefix in the below example (blue blocks are cached): + +![Prefix Caching of Full Attention](../assets/design/hybrid_kv_cache_manager/full_attn.png) + +### Case 1: sliding window attention only models + +For sliding window attention layers, a naive implementation for memory allocation is to allocate `sliding_window_size` blocks and fill in the blocks in a round-robin way. But this naive implementation is not compatible with prefix caching so we didn't pick this design. In vLLM, we allocate different blocks for different tokens and free blocks that are outside the sliding window. + +For a new request, the cache hit prefix only requires the last `sliding_window_size - 1` tokens being cached. +Let's say `sliding_window_size = 4` and `block_size = 1`, and the request is a 15-token prompt (blue blocks are cached): + +![Prefix Caching of Sliding Window Attention](../assets/design/hybrid_kv_cache_manager/sw_attn.png) + +There are 3 possible cache hit prefixes: + +- cache hit length 5, compute prefill with [2, 3, 4] → [5, 6, …, 14] +- cache hit length 6, compute prefill with [3, 4, 5] → [6, 7, …, 14] +- cache hit length 14, compute prefill with [11, 12, 13] → [14] (most efficient) + +We can check the cache hit from right to left, and early exit when we find a match.This is opposite from full attention, where we check from left to right and early exit when the match fails. One potential cons (compared to full attention) is that we end up iterating over the entire list of tokens when there's no match, which is often a common case. This could potentially cause non-negligible overheads, but fine with full + swa, as discussed below. + +### Case 2: sliding window attention + full attention models + +The first problem is how to find the cache hit prefix. We need to "intersect" the cache hits of global and sliding window attention layers by: + +1. Get the longest cache hit for full attention (scanning from left to right) +2. Get the longest cache hit for sliding window attention that is within that length. Implemented by checking cache hits from right to left starting from the cache hit length of full attention. + +It can be ensured that the resulting cache hit of sliding window attention layers is also a cache hit of full attention layers. This is more efficient than finding all possible prefixes of each group and doing the intersection, because our approach can exit early if there is no cache hit. + +The algorithm applies to models with exactly two attention types full attention + X, where X can be an arbitrary efficient attention algorithm like sliding window, llama 4 local attention, and mamba. It doesn't support models without full attention layers, and models with more than 2 types of attention. This is enough for most hybrid models at the moment of writing this doc. + +The second question is the cache eviction policy. For now, we use one LRU queue for all kv cache groups. The blocks are added to the LRU queue when freed, either because the request is finished or the block is out of the sliding window. + +### Case 3: mamba models + +The prefix caching support of the mamba model is work in progress. Once implemented, models with mamba layer + full attention layer can be supported via the full attention + X algorithm in case 2. + +## Implementation + +### Overview + +![Overview of Hybrid KV Cache Manager](../assets/design/hybrid_kv_cache_manager/overview.png) + +The `KVCacheManager` is organized into 3 layers: + +- **[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager]**: The interface between the scheduler and kv cache management system. +- **[KVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinator]**: coordinate per-group SingleTypeKVCacheManagers to generate the allocation result of a request. Depending on the model's configuration, one of these coordinators is chosen: + - **[KVCacheCoordinatorNoPrefixCache][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinatorNoPrefixCache]**: Used when prefix caching is disabled. + - **[UnitaryKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.UnitaryKVCacheCoordinator]**: If only one KV cache group. The prefix caching logic is simplified as no intersection is needed. + - **[HybridKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.HybridKVCacheCoordinator]**: Handles exactly two KV cache groups (must include one full‑attention group plus one other efficient‑attention group). Other cases are not implemented. You can disable prefix caching to use the KVCacheCoordinatorNoPrefixCache. +- **[SingleTypeKVCacheManager][vllm.v1.core.single_type_kv_cache_manager.SingleTypeKVCacheManager]**: Each instance manages allocation and prefix caching for one KV cache group, implementing the attention‑type–specific logic (e.g., full attention, sliding window, Mamba). + +The blue box in the above figure shows the case with 10 full attention layers and 20 sliding window attention layers, thus: + +- use `HybridKVCacheCoordinator` +- use 1 `FullAttentionManager` and 2 `SlidingWindowManager` for the 3 `KVCacheGroup`s. + +### Memory Layout + +For a model with n `KVCacheGroup`s, each with m layers, we allocate m buffers. Each buffer is shared by n layers, one from each group. + +The following figure is for a model with 10 full attention layers (full.0 - full.9) and 20 sliding window attention layers (sw.0-sw.19). It follows "case 2" in "Allocation" section and is divided into 3 groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +And for a request, we allocate 11 blocks with `block_id` 0-6 to group 0, 7-8 to group 1, and 9-10 to group 2. + +With such an example, the physical memory is divided into 10 buffers (`KVCacheTensor` 0 - `KVCacheTensor` 9). Each buffer is shared by 3 layers (e.g., `KVCacheTensor` 0 is shared by full.0 from group 0, sw.0 from group 1, and sw.10 from group 2) and is divided into pieces with size `block_size * kv_hidden_size`. The KV cache of these 3 attention layers are saved to different pieces of the buffer based on the allocated `block_ids`: + +![Example Memory Layout](../assets/design/hybrid_kv_cache_manager/memory_layout.png) + +!!! note + One logic "block" is mapped to 10 pieces in the 10 buffers of the physical memory. diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md new file mode 100644 index 0000000000000000000000000000000000000000..8e5d5249409c661febd86c7a9b33335796c948cf --- /dev/null +++ b/docs/design/io_processor_plugins.md @@ -0,0 +1,78 @@ +# IO Processor Plugins + +IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output. + +When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggerd via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint. + +## Writing an IO Processor Plugin + +IO Processor plugins implement the `IOProcessor` interface (): + +```python +IOProcessorInput = TypeVar('IOProcessorInput') +IOProcessorOutput = TypeVar('IOProcessorOutput') + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process(self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + collected_output = [item async for i, item in model_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + raise NotImplementedError +``` + +The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. +The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. +The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. + +The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py). + +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples. + +## Using an IO Processor plugin + +IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded: + +1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode. +2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json). + +The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json). diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b01838883f31e851e6102d2137d9f5a8389097da..90b2fd32f2979421466c6f7efafb1cb50888f53c 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -99,11 +99,11 @@ http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201 ### Multi-process Mode -In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See . +In v0, metrics are collected in the engine core process and we use multiprocess mode to make them available in the API server process. See . ### Built in Python/Process Metrics -The following metrics are supported by default by `prometheus_client`, but they are not exposed when multi-process mode is used: +The following metrics are supported by default by `prometheus_client`, but they are not exposed when multiprocess mode is used: - `python_gc_objects_collected_total` - `python_gc_objects_uncollectable_total` @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! note diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 06ebd77258582baed6e6d8bcd659d36325f52aa4..247072d1cb27517bb9fbdff2ecd0ac95848ea054 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35caf30ba16b3473d56794c8b8fc232e4d..d87b2a639df1260f9dcf9e10f62d9a75452afbbd 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index ca1c2c2305d916a989c6ca812b457d689b0653aa..37193809776ae03191f4815e856c18b180ea7c8b 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -49,6 +49,8 @@ Every plugin has three parts: - **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported. +- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for poling models. The plugin function returns the IOProcessor's class fully qualified name. + ## Guidelines for Writing Plugins - **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. diff --git a/docs/examples/README.md b/docs/examples/README.md index 34e4dfd408a205e5021fcf63ba5039f26036e81a..3cf93027f4209c0b4cde6486737d6605f462f6b1 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -2,6 +2,6 @@ vLLM's examples are split into three categories: -- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/) -- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/) -- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/) +- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) +- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) +- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) diff --git a/docs/features/lora.md b/docs/features/lora.md index 668460a368a77dfeb14cda93d073887faa4bf504..db794b2ebd71d85eea6331a2f29836b1ebf7258d 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -52,7 +52,7 @@ Check out for an exa ## Serving LoRA Adapters LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use -`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kickoff the server: +`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kick off the server: ```bash vllm serve meta-llama/Llama-2-7b-hf \ diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 9d51f9cf52f503591bbc6691d1ee6c4fae03fec3..206ab7a4687552cf1196986a31f448c6c35e25c7 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -13,6 +13,41 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: - `prompt`: The prompt should follow the format that is documented on HuggingFace. - `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. +### Stable UUIDs for Caching (multi_modal_uuids) + +When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content. + +??? code + + ```python + from vllm import LLM + from PIL import Image + + # Qwen2.5-VL example with two images + llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct") + + prompt = "USER: \nDescribe the differences.\nASSISTANT:" + img_a = Image.open("/path/to/a.jpg") + img_b = Image.open("/path/to/b.jpg") + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": {"image": [img_a, img_b]}, + # Provide stable IDs for caching. + # Requirements (matched by this example): + # - Include every modality present in multi_modal_data. + # - For lists, provide the same number of entries. + # - Use None to fall back to content hashing for that item. + "multi_modal_uuids": {"image": ["sku-1234-a", None]}, + }) + + for o in outputs: + print(o.outputs[0].text) + ``` + +!!! warning + If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored. + ### Image Inputs You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index e18c128f30fc969ec69f98778ccbbfcbafc2dfc5..4605ba7781ed4a653e80776669ac6b4ec33d4df3 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -4,7 +4,6 @@ Quantization trades off model precision for smaller memory footprint, allowing l Contents: -- [Supported Hardware](supported_hardware.md) - [AutoAWQ](auto_awq.md) - [AutoRound](auto_round.md) - [BitsAndBytes](bnb.md) @@ -19,3 +18,50 @@ Contents: - [AMD Quark](quark.md) - [Quantized KV Cache](quantized_kvcache.md) - [TorchAO](torchao.md) + +## Supported Hardware + +The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: + + + +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | + +- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. + +!!! note + This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. + + For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 6f53a448ee364120c98e050d83a3308d6fa256ad..53b689ad53ff689b6c9b6380e85c4470145e82d3 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -5,7 +5,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic !!! note Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. - For details see [supported hardware](supported_hardware.md). + For details see [supported hardware](README.md#supported-hardware). Below are the steps to utilize BitBLAS with vLLM. diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index 0661933acd61ff6069610bede8a0feb3eeef3fb4..834c03cbe05b031d5d275687cb0d64bed7f32874 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio Install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` Load and run the model in `vllm`: diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index 13b151bc7f380d73141efa5e577a889fc33b76b9..5e86e9388f328be3f733d0b4ec815ae9dad255c3 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 127e403989944fe3c2942ecb5dfae8564800478f..d6fdac7b07f7f98986d9a57b141f6f2baddacb1c 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -18,7 +18,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 45fae58a64868c1e9b1e7f4749d8a9e69dc61cc3..247d0cbdd3f149220ace2b12309b1097ec35c045 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -19,7 +19,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index e8ed2155375d419a846b884de9b37ca774f88154..047cc8382445b9d005fa05e363e9bab8080fe067 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -20,7 +20,7 @@ for more installation details. Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md deleted file mode 100644 index 06264d08b56aa9165a00cc229bae0be23ab92052..0000000000000000000000000000000000000000 --- a/docs/features/quantization/supported_hardware.md +++ /dev/null @@ -1,32 +0,0 @@ -# Supported Hardware - -The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: - - - -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | - -- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- ✅︎ indicates that the quantization method is supported on the specified hardware. -- ❌ indicates that the quantization method is not supported on the specified hardware. - -!!! note - This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - - For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 04b943efbbbb4efc50f32eac690f35b6fc89aaf1..d9a785eb73fbeda852c5266319f2e0914e3fbc07 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -143,7 +143,7 @@ OpenAI Python client library does not officially support `reasoning_content` att print(content, end="", flush=True) ``` -Remember to check whether the `reasoning_content` exists in the response before accessing it. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). +Remember to check whether the `reasoning_content` exists in the response before accessing it. You could check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). ## Tool Calling diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 8a934d406f3821a37051745a4ad635da7fe73760..0d6294a5fdd798ade7c0343ce1d0183b3d7fc87f 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -205,7 +205,7 @@ This section covers the OpenAI beta wrapper over the `client.chat.completions.cr At the time of writing (`openai==1.54.4`), this is a "beta" feature in the OpenAI client library. Code reference can be found [here](https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/resources/beta/chat/completions.py#L100-L104). -For the following examples, vLLM was setup using `vllm serve meta-llama/Llama-3.1-8B-Instruct` +For the following examples, vLLM was set up using `vllm serve meta-llama/Llama-3.1-8B-Instruct` Here is a simple example demonstrating how to get structured output using Pydantic models: diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 37d502ef9ce0a5ef944a158136f37a4c30cace50..afc605a504b3df4e188b9058e3a2ce8327c51087 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -284,6 +284,14 @@ Supported models: Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` +### DeepSeek-V3.1 Models (`deepseek_v31`) + +Supported models: + +* `deepseek-ai/DeepSeek-V3.1` (use with ) + +Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}` + ### Kimi-K2 Models (`kimi_k2`) Supported models: diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index 0ee680f5c688c8a63d7401b477c2ae8dc2e35f88..8a658b7a9103fc3f875cb72e3e8a28863d87346b 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -12,7 +12,6 @@ vLLM supports the following hardware platforms: - [Apple silicon](cpu.md#apple-silicon) - [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [Google TPU](google_tpu.md) -- [Intel Gaudi](intel_gaudi.md) - [AWS Neuron](aws_neuron.md) ## Hardware Plugins diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md index b8bd76bd5bcbeb1d60afc9a3eeafea3bdb7f514f..ff2500f0352709d2c95a6fcb5d9774bbe45ba15f 100644 --- a/docs/getting_started/installation/aws_neuron.md +++ b/docs/getting_started/installation/aws_neuron.md @@ -140,8 +140,8 @@ Alternatively, users can directly call the NxDI library to trace and compile you - `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the - artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, - but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts + artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set, + but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts under this specified path. - `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). - `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e494d5bcc1dcffd01349196e954d3ad..7f0ecb2bc0b74201070d8e8b293266e96eb01a7c 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels. - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). @@ -170,7 +171,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +180,7 @@ Inference batch size is a important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +191,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 2828173a76a9a2aea3d3370087238c5215127f76..124a41adf1ae2899886d35b6400c194f38819e6a 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM has experimental support for macOS with Apple silicon. For now, users must build from source to natively run on macOS. +vLLM has experimental support for macOS with Apple Silicon. For now, users must build from source to natively run on macOS. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 6dc6f94249c340c45004ce0ec353283689bb0fc1..f7af259ace62806a37ba2731250cab4da6bc7315 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -43,7 +43,7 @@ docker build -f docker/Dockerfile.cpu \ # Launching OpenAI server docker run --rm \ - --privileged=true \ + --security-opt seccomp=unconfined \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 69a9842e4719b82520c1f2f06b931a57ee4c9572..275232e12e08c266ef66851ae2d7c8b76cdc0e10 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -48,7 +48,7 @@ uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VE #### Install the latest code -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on a x86 platform with CUDA 12 for every commit since `v0.5.3`. +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. ```bash uv pip install -U vllm \ diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 560883d3caf9e981bfe95ae17cab0aea6714576c..80e99d3034d39fe3c65cb46baed903cc4a725f28 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -149,7 +149,7 @@ Build a docker image from which setup ROCm **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** If you choose to build this rocm_base image yourself, the steps are as follows. -It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```json { @@ -170,7 +170,7 @@ DOCKER_BUILDKIT=1 docker build \ #### Build an image with vLLM First, build a docker image from and launch a docker container from the image. -It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```bash { diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md deleted file mode 100644 index 61b2b02aa10ba535ee61b1d3e7b2ff21a0f4c3af..0000000000000000000000000000000000000000 --- a/docs/getting_started/installation/intel_gaudi.md +++ /dev/null @@ -1,388 +0,0 @@ -# Intel Gaudi - -This page provides instructions on running vLLM with Intel Gaudi devices. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - -## Requirements - -- OS: Ubuntu 22.04 LTS -- Python: 3.10 -- Intel Gaudi accelerator -- Intel Gaudi software version 1.18.0 - -Please follow the instructions provided in the -[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the execution environment. To achieve the best performance, -please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). - -## Configure a new environment - -### Environment verification - -To verify that the Intel Gaudi software was correctly installed, run: - -```bash -hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible -apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed -pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed -pip list | grep neural # verify that neural_compressor_pt is installed -``` - -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) -for more details. - -### Run Docker Image - -It is highly recommended to use the latest Docker image from Intel Gaudi -vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) -for more details. - -Use the following commands to run a Docker image: - -```bash -docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --ipc=host \ - vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -``` - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Intel Gaudi wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -r requirements/hpu.txt -python setup.py develop -``` - -Currently, the latest features and performance optimizations are developed in Gaudi's [vLLM-fork](https://github.com/HabanaAI/vllm-fork) and we periodically upstream them to vLLM main repo. To install latest [HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the following: - -```bash -git clone https://github.com/HabanaAI/vllm-fork.git -cd vllm-fork -git checkout habana_main -pip install -r requirements/hpu.txt -python setup.py develop -``` - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Intel Gaudi images. - -### Build image from source - -```bash -docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --rm vllm-hpu-env -``` - -!!! tip - If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. - -## Extra information - -### Supported features - -- [Offline inference](../../serving/offline_inference.md) -- Online serving via [OpenAI-Compatible Server](../../serving/openai_compatible_server.md) -- HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi accelerators -- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, - prefill attention, Root Mean Square Layer Normalization, Rotary - Positional Encoding -- Tensor parallelism support for multi-card inference -- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) - for accelerating low-batch latency and throughput -- Attention with Linear Biases (ALiBi) -- INC quantization - -### Unsupported features - -- Beam search -- LoRA adapters -- AWQ quantization -- Prefill chunking (mixed-batch inferencing) - -### Supported configurations - -The following configurations have been validated to function with -Gaudi2 devices. Configurations that are not listed may or may not work. - -| Model | TP Size| dtype | Sampling | -|-------|--------|--------|----------| -| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy | - -## Performance tuning - -### Execution modes - -Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment variable), and `--enforce-eager` flag. - -| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | -|----------------------|-------------------|--------------------| -| 0 | 0 | torch.compile | -| 0 | 1 | PyTorch eager mode | -| 1 | 0 | HPU Graphs | - -!!! warning - In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. - -[](){ #gaudi-bucketing-mechanism } - -### Bucketing mechanism - -Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. -In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - `batch_size` and `sequence_length`. - -!!! note - Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. - -Bucketing ranges are determined with 3 parameters - `min`, `step` and `max`. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: - -```text -INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -``` - -| Parameter | Description | -|----------------|-----------------------------------------------------------------------------| -| `min` | Determines the lowest value of the bucket. | -| `step` | Determines the interval between buckets. | -| `max` | Determines the upper bound of the bucket. | -| Ramp-up phase | A special handling phase applied between `min` and `step`:
- `min` is multiplied by consecutive powers of two until `step` is reached.
- Minimizes resource wastage for small batch sizes.
- Allows larger padding for larger batches. | - -Example (with ramp-up): - -```text -min = 2, step = 32, max = 64 -=> ramp_up = (2, 4, 8, 16) -=> stable = (32, 64) -=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) -``` - -Example (without ramp-up): - -```text -min = 128, step = 128, max = 512 -=> ramp_up = () -=> stable = (128, 256, 384, 512) -=> buckets = ramp_up + stable => (128, 256, 384, 512) -``` - -In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. - -!!! warning - If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. - -As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as `(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as `(4, 512)` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a `(2, 512)` bucket, or context length increases above 512 tokens, in which case it will become `(4, 640)` bucket. - -!!! note - Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. - -### Warmup - -Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: - -??? console "Logs" - - ```text - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB - INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB - ... - INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB - INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB - ... - INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - ``` - -This example uses the same buckets as in the [Bucketing Mechanism][gaudi-bucketing-mechanism] section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. - -!!! tip - Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. - -### HPU Graph capture - -[HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. - -When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by `gpu_memory_utilization` flag (`0.9` by default). -Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, `gpu_memory_utilization` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. -Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. -Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of memory reserved for HPU Graphs capture. -With its default value (`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. -Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.3`), both stages have equal memory constraints. -Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. - -!!! note - `gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. - -User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: - -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt - -When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. - -!!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. - -Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): - -??? console "Logs" - - ```text - INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache - INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 - INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB - ... - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - ... - INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB - INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB - ... - INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB - INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory - INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) - ``` - -### Recommended vLLM Parameters - -- We recommend running inference on Gaudi 2 with `block_size` of 128 - for BF16 data type. Using default values (16, 32) might lead to - sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). -- For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. - -### Environment variables - -**Diagnostic and profiling knobs:** - -- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default. - -**Performance tuning knobs:** - -- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by default - -- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for HPUGraph capture, `0.1` by default - -- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default - -- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default - -- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default - -- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - - - `{phase}` is either `PROMPT` or `DECODE` - - - `{dim}` is either `BS`, `SEQ` or `BLOCK` - - - `{param}` is either `MIN`, `STEP` or `MAX` - - - Default values: - -| `{phase}` | Parameter | Env Variable | Value Expression | -|-----------|-----------|--------------|------------------| -| Prompt | Batch size min | `VLLM_PROMPT_BS_BUCKET_MIN` | `1` | -| Prompt | Batch size step | `VLLM_PROMPT_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Prompt | Batch size max | `VLLM_PROMPT_BS_BUCKET_MAX` | `min(max_num_seqs, 64)` | -| Prompt | Sequence length min | `VLLM_PROMPT_SEQ_BUCKET_MIN` | `block_size` | -| Prompt | Sequence length step | `VLLM_PROMPT_SEQ_BUCKET_STEP` | `block_size` | -| Prompt | Sequence length max | `VLLM_PROMPT_SEQ_BUCKET_MAX` | `max_model_len` | -| Decode | Batch size min | `VLLM_DECODE_BS_BUCKET_MIN` | `1` | -| Decode | Batch size step | `VLLM_DECODE_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Decode | Batch size max | `VLLM_DECODE_BS_BUCKET_MAX` | `max_num_seqs` | -| Decode | Sequence length min | `VLLM_DECODE_BLOCK_BUCKET_MIN` | `block_size` | -| Decode | Sequence length step | `VLLM_DECODE_BLOCK_BUCKET_STEP` | `block_size` | -| Decode | Sequence length max | `VLLM_DECODE_BLOCK_BUCKET_MAX` | `max(128, (max_num_seqs*max_model_len)/block_size)` | - -Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - -- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default. -- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs - -## Troubleshooting: tweaking HPU graphs - -If you experience device out-of-memory issues or want to attempt -inference at higher batch sizes, try tweaking HPU Graphs by following -the below: - -- Tweak `gpu_memory_utilization` knob. It will decrease the - allocation of KV cache, leaving some headroom for capturing graphs - with larger batch size. By default `gpu_memory_utilization` is set - to 0.9. It attempts to allocate ~90% of HBM left for KV cache after - short profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. -- If this method is not efficient, you can disable `HPUGraph` - completely. With HPU Graphs disabled, you are trading latency and - throughput at lower batches for potentially higher throughput on - higher batches. You can do that by adding `--enforce-eager` flag to - server (for online serving), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index f833807666460268dab448ad4d6b153dad93b8eb..2af26626d207ded68d9025c1c9067babd8335243 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform: ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.9 -- 3.13 ## Installation diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ed5d3b0092ae70e3b619b705e0e5eabd53da9a70..051a2d904406da2b3b1dda3c084d5a460beb6b25 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import logging import sys from argparse import SUPPRESS, HelpFormatter @@ -7,25 +8,52 @@ from pathlib import Path from typing import Literal from unittest.mock import MagicMock, patch +from pydantic_core import core_schema + +logger = logging.getLogger("mkdocs") + ROOT_DIR = Path(__file__).parent.parent.parent.parent ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) -sys.modules["aiohttp"] = MagicMock() -sys.modules["blake3"] = MagicMock() sys.modules["vllm._C"] = MagicMock() -from vllm.benchmarks import latency # noqa: E402 -from vllm.benchmarks import serve # noqa: E402 -from vllm.benchmarks import throughput # noqa: E402 -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 -from vllm.entrypoints.cli.openai import ChatCommand # noqa: E402 -from vllm.entrypoints.cli.openai import CompleteCommand # noqa: E402 -from vllm.entrypoints.openai import cli_args # noqa: E402 -from vllm.entrypoints.openai import run_batch # noqa: E402 -from vllm.utils import FlexibleArgumentParser # noqa: E402 -logger = logging.getLogger("mkdocs") +class PydanticMagicMock(MagicMock): + """`MagicMock` that's able to generate pydantic-core schemas.""" + + def __get_pydantic_core_schema__(self, source_type, handler): + return core_schema.any_schema() + + +def auto_mock(module, attr, max_mocks=50): + """Function that automatically mocks missing modules during imports.""" + logger.info("Importing %s from %s", attr, module) + for _ in range(max_mocks): + try: + # First treat attr as an attr, then as a submodule + return getattr(importlib.import_module(module), attr, + importlib.import_module(f"{module}.{attr}")) + except importlib.metadata.PackageNotFoundError as e: + raise e + except ModuleNotFoundError as e: + logger.info("Mocking %s for argparse doc generation", e.name) + sys.modules[e.name] = PydanticMagicMock() + + raise ImportError( + f"Failed to import {module}.{attr} after mocking {max_mocks} imports") + + +latency = auto_mock("vllm.benchmarks", "latency") +serve = auto_mock("vllm.benchmarks", "serve") +throughput = auto_mock("vllm.benchmarks", "throughput") +AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") +EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") +ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") +CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") +cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") +run_batch = auto_mock("vllm.entrypoints.openai", "run_batch") +FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser") class MarkdownFormatter(HelpFormatter): diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 1e8b848db46d86381b85783fc3766d991c0bea8d..881df791698e26b2ebb068fa2d2672160099769c 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -70,6 +70,10 @@ class Example: self.other_files = self.determine_other_files() self.title = self.determine_title() + @property + def is_code(self) -> bool: + return self.main_file.suffix != ".md" + def determine_main_file(self) -> Path: """ Determines the main file in the given path. @@ -101,6 +105,12 @@ class Example: return [file for file in self.path.rglob("*") if is_other_file(file)] def determine_title(self) -> str: + if not self.is_code: + with open(self.main_file) as f: + first_line = f.readline().strip() + match = re.match(r'^#\s+(?P.+)$', first_line) + if match: + return match.group('title') return fix_case(self.path.stem.replace("_", " ").title()) def generate(self) -> str: @@ -110,11 +120,13 @@ class Example: # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - is_code = self.main_file.suffix != ".md" - if is_code: + # Skip the title from md snippets as it's been included above + start_line = 2 + if self.is_code: content += f"{code_fence}{self.main_file.suffix[1:]}\n" - content += f'--8<-- "{self.main_file}"\n' - if is_code: + start_line = 1 + content += f'--8<-- "{self.main_file}:{start_line}"\n' + if self.is_code: content += f"{code_fence}\n" content += "\n" diff --git a/docs/mkdocs/javascript/mathjax.js b/docs/mkdocs/javascript/mathjax.js new file mode 100644 index 0000000000000000000000000000000000000000..5da0d443578c42a978d7bd1d1194d6f49b8d718e --- /dev/null +++ b/docs/mkdocs/javascript/mathjax.js @@ -0,0 +1,20 @@ +// Enables MathJax rendering +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index a64ecd31ebaefe9476193ee0092e5bce3c647861..d02522a6657de0433cc6f7e925e911754b76db1a 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -19,7 +19,7 @@ Run a model in generation mode via the option `--runner generate`. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.generate` diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 39f209d0eb7ed40e71ee490f80f0e57fd552a89c..d2fbb1870dde60201f18ecace9d2ebdc942df66b 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -81,7 +81,7 @@ which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.embed` @@ -205,12 +205,12 @@ Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions. -For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). +For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). Here is an example to serve a model with Matryoshka Embeddings enabled. ```text -vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}' ``` ### Offline Inference @@ -258,4 +258,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -A openai client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> +An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a514572945c3f44c24ade9f467c28ef1ce9a9bf5..e8fe77e8d6c98423ece732ec2afac78c8e0ae21a 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -40,7 +40,7 @@ If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it mean #### Custom models -If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! +If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! For a model to be compatible with the Transformers backend for vLLM it must: @@ -328,16 +328,16 @@ th { | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | | `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | @@ -358,12 +358,12 @@ th { | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | @@ -373,6 +373,7 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | @@ -384,8 +385,8 @@ th { | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -400,6 +401,7 @@ th { | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -436,17 +438,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | -| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | -| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | -| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | +| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | +| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | +| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | +| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ | +| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ | | `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion)) @@ -476,7 +478,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | @@ -493,12 +495,13 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | | +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) @@ -511,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. @@ -613,6 +619,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | @@ -624,9 +632,10 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | @@ -634,13 +643,14 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -651,6 +661,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ | @@ -696,7 +707,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. !!! note - Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index e78c67522f61b7ae82226ca55c0abd81c38c75d3..4c7a7ff019e8ca02240694bbe9eeba1df73dc7e9 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -51,7 +51,7 @@ tail ~/.config/vllm/usage_stats.json ## Opting out -You can opt-out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: +You can opt out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: ```bash # Any of the following methods can disable usage stats collection diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 54af970ea842de2bf8a58bd82190d05122850ab1..f71805436a6ae064cace4e5f3beae74ba0b9d8a9 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,15 +107,14 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported. -Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that -these models currently require disabling prefix caching and using the FlashInfer attention backend in V1. +Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). -Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`). -Please note that these models currently require disabling prefix caching, enforcing eager mode, and using the FlashInfer -attention backend in V1. +Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`). + +Please note that prefix caching is not yet supported for any of the above models. #### Encoder-Decoder Models @@ -154,16 +153,19 @@ differences compared to V0: ##### Logprobs Calculation -Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. +By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling. -Support for logprobs with post-sampling adjustments is in progress and will be added in future updates. +You can adjust this behavior by setting the `--logprobs-mode` flag. +Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`. +Raw means the values before applying any logit processors, like bad words. +Processed means the values after applying all processors, including temperature and top_k/top_p. ##### Prompt Logprobs with Prefix Caching -Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414). +Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs. #### Deprecated Features diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ba27cd1e027d5f01560e07ab0e9161129ff35d --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `<pad>` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. " +decoder_prompt = f"<s>{prompt}<Answer/>" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. " + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e7622e637fc873618c093a5a619be..655f9f3fce7ae23c87938c0c6f681f7927f07ba9 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `<pad>` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. + prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py index 7ef20efa7d28cc267765486fd7941b1e784f1214..3e122319169ebc6bff9d24fe41fc0090725e86ff 100644 --- a/examples/offline_inference/logits_processor.py +++ b/examples/offline_inference/logits_processor.py @@ -42,8 +42,8 @@ from vllm.config import VllmConfig from vllm.v1.sample.logits_processor import ( BatchUpdate, LogitsProcessor, - MoveDirectionality, ) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates # Hypothetical custom logits processor @@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor): def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): - self.req_info: dict[int, SamplingParams] = {} + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests. - for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and ( - target_token := params.extra_args.get("target_token") - ): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + process_dict_updates( + self.req_info, + batch_update, + # This function returns the LP's per-request state based on the + # request details, or None if this LP does not apply to the + # request. + lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..8023cd6677762762b72e781894d989f51fcbd6fc --- /dev/null +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import os + +import torch + +from vllm import LLM +from vllm.pooling_params import PoolingParams + +# This example shows how to perform an offline inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Reuirement - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin + + +def main(): + torch.set_default_dtype(torch.float16) + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + llm = LLM( + model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM. + # The maximum number depends on the available GPU memory + max_num_seqs=32, + io_processor_plugin="prithvi_to_tiff_india", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + pooler_output = llm.encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + print(output) + decoded_data = base64.b64decode(output.data) + + file_path = os.path.join(os.getcwd(), "offline_prediction.tiff") + with open(file_path, "wb") as f: + f.write(decoded_data) + + print(f"Output file path: {file_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891eca7a600fc7b96ca345ec83d8f37037..5af232cb6af6a133adffbb69e1fb2f4c16a89f03 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: @@ -137,7 +138,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index f46064931dbac62cc4f83a8c28b8917f72aa4d2f..88d87beb4874de7335615b28ff792e2b7a426a0d 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -85,7 +85,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 988ad35cdd7e6e43982729ae146a38e765b4d7aa..4e879666f61d7cc98d986ab1ea429c7691f780d2 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: <think></think>" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -283,8 +314,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -333,6 +366,80 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.5V +def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# GLM-4.5V-FP8 +def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -383,8 +490,8 @@ def run_hyperclovax_seed_vision( for question in questions: if modality == "image": """ - ocr: List the words in the image in raster order. - Even if the word order feels unnatural for reading, + ocr: List the words in the image in raster order. + Even if the word order feels unnatural for reading, the model will handle it as long as it follows raster order. e.g. "Naver, CLOVA, bigshane" lens_keywords: List the entity names in the image. @@ -693,15 +800,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user <video>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] elif modality == "image": prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -815,6 +920,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") +def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "MiniMaxAI/MiniMax-VL-01" + + engine_args = EngineArgs( + model=model_name, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + tensor_parallel_size=8, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Mistral-3 HF-format def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -891,8 +1029,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -998,6 +1135,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData: ) +# Ovis2_5 +def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={modality: 1}, + ) + if modality == "image": + placeholder = "<image>" + elif modality == "video": + placeholder = "<video>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1297,6 +1466,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# R-4B +def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "YannQi/R-4B" + + prompts = [ + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1442,12 +1633,15 @@ model_example_map = { "chameleon": run_chameleon, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, "gemma3n": run_gemma3n, "glm4v": run_glm4v, "glm4_1v": run_glm4_1v, + "glm4_5v": run_glm4_5v, + "glm4_5v_fp8": run_glm4_5v_fp8, "h2ovl_chat": run_h2ovl, "hyperclovax_seed_vision": run_hyperclovax_seed_vision, "idefics3": run_idefics3, @@ -1463,12 +1657,14 @@ model_example_map = { "mantis": run_mantis, "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, + "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, "mllama": run_mllama, "molmo": run_molmo, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, "ovis": run_ovis, + "ovis2_5": run_ovis2_5, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, @@ -1479,6 +1675,7 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "step3": run_step3, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 799337ed68503cee0e51509b5d6ffdd3d4e15449..d9242efa85470aaf8186d51a739f2471e0b59013 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -680,6 +680,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: ) +# ovis2_5 +def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -962,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "YannQi/R-4B" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" @@ -1064,6 +1127,76 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: ) +# GLM-4.5V +def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + +# GLM-4.5V-FP8 +def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, @@ -1085,6 +1218,7 @@ model_example_map = { "mllama": load_mllama, "NVLM_D": load_nvlm_d, "ovis": load_ovis, + "ovis2_5": load_ovis2_5, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "phi4_multimodal": load_phi4_multimodal, @@ -1092,10 +1226,13 @@ model_example_map = { "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, + "rvl": load_r_vl, "smolvlm": load_smolvlm, "step3": load_step3, "tarsier": load_tarsier, "tarsier2": load_tarsier2, + "glm4_5v": load_glm4_5v, + "glm4_5v_fp8": load_glm4_5v_fp8, } diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4e400044263113f516909fd9464d651..f238c66234dcca620059806461e1e33d56a1aad5 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent): token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd34f461362cdf012cf1950e7f7d36cd0d8eeb9 --- /dev/null +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import os + +import requests + +# This example shows how to perform an online inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Reuirements : +# - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin +# - start vllm in serving mode with the below args +# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' +# --task embed --trust-remote-code +# --skip-tokenizer-init --enforce-eager +# --io-processor-plugin prithvi_to_tiff_india + + +def main(): + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + server_endpoint = "http://localhost:8000/pooling" + + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + } + + ret = requests.post(server_endpoint, json=request_payload_url) + + print(f"response.status_code: {ret.status_code}") + print(f"response.reason:{ret.reason}") + + response = ret.json() + + decoded_image = base64.b64decode(response["data"]["data"]) + + out_path = os.path.join(os.getcwd(), "online_prediction.tiff") + + with open(out_path, "wb") as f: + f.write(decoded_image) + + +if __name__ == "__main__": + main() diff --git a/examples/tool_chat_template_deepseekv31.jinja b/examples/tool_chat_template_deepseekv31.jinja new file mode 100644 index 0000000000000000000000000000000000000000..863be69d60b68368662d68f3f584ff47a5b4b27b --- /dev/null +++ b/examples/tool_chat_template_deepseekv31.jinja @@ -0,0 +1,91 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} +{% if not thinking is defined %} + {% set thinking = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} + {% endfor %} + {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content']}} + {%- endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- if ns.is_last_user %} + {{'<|Assistant|></think>'}} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{'<think>'}} + {%- else %} + {{'</think>'}} + {%- endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '</think>' in content %} + {%- set content = content.split('</think>', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{'</think>'}} + {%- else %} + {{'<think>'}} + {%- endif %} +{% endif %} diff --git a/examples/tool_chat_template_gemma3_pythonic.jinja b/examples/tool_chat_template_gemma3_pythonic.jinja new file mode 100644 index 0000000000000000000000000000000000000000..5a20b0191129525c95122b60a99d64cc0ba06d88 --- /dev/null +++ b/examples/tool_chat_template_gemma3_pythonic.jinja @@ -0,0 +1,123 @@ +{#- Begin-of-sequence token to start the model prompt -#} +{{ bos_token }} +{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#} +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} +{%- endfor -%} + +{#- Main loop over all messages in the conversation history -#} +{%- for message in loop_messages -%} + {#- Normalize roles for model prompt formatting -#} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- elif (message['role'] == 'tool') -%} + {%- set role = "user" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {#- Mark the start of a message block with the appropriate role -#} + {{ '<start_of_turn>' + role + '\n' -}} + + {#- Insert system message content (if present) at the beginning of the first message. -#} + {%- if loop.first -%} + {{ first_user_prefix }} + {#- Append system message with tool information if using tools in message request. -#} + {%- if tools is not none -%} + {{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}} + {{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}} + {{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}} + {{- "Here is a list of functions in JSON format that you can invoke.\n" -}} + {{- tools | tojson(indent=4) -}} + {{- "\n\n" -}} + {%- endif -%} + {%- endif -%} + + {#- Format model tool calls (turns where model indicates they want to call a tool) -#} + {%- if 'tool_calls' in message -%} + {#- Opening bracket for tool call list. -#} + {{- '[' -}} + {#- For each tool call -#} + {%- for tool_call in message.tool_calls -%} + {#- Get tool call function. -#} + {%- if tool_call.function is defined -%} + {%- set tool_call = tool_call.function -%} + {%- endif -%} + {#- Function name & opening parenthesis. -#} + {{- tool_call.name + '(' -}} + + {#-- Handle arguments as list (positional) or dict (named) --#} + {#-- Named arguments (dict) --#} + {%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%} + {%- set first = true -%} + {%- for key, val in tool_call.arguments.items() -%} + {%- if not first %}, {% endif -%} + {{ key }}={{ val | tojson }} + {%- set first = false -%} + {%- endfor -%} + {#-- Positional arguments (list) --#} + {%- elif tool_call.arguments is iterable -%} + {{- tool_call.arguments | map('tojson') | join(', ') -}} + {#-- Fallback: single positional value --#} + {%- else -%} + {{- tool_call.arguments | tojson -}} + {#-- Closing parenthesis. --#} + {%- endif -%} + {{- ')' -}} + {#-- If more than one tool call, place comma and move to formatting next tool call --#} + {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + {#- Closing bracket for tool call list. -#} + {{- ']' -}} + {%- endif -%} + + {#- Tool response start tag (for messages from a tool) -#} + {%- if (message['role'] == 'tool') -%} + {{ '<tool_response>\n' -}} + {%- endif -%} + + {#- Render the message content: handle plain string or multimodal content like image/text -#} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '<start_of_image>' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + + {#- Tool response end tag -#} + {%- if (message['role'] == 'tool') -%} + {{ '</tool_response>' -}} + {%- endif -%} + + {#- Mark end of a single turn -#} + {{ '<end_of_turn>\n' }} +{%- endfor -%} + +{#- If generation is to be triggered, add model prompt prefix -#} +{%- if add_generation_prompt -%} + {{'<start_of_turn>model\n'}} +{%- endif -%} \ No newline at end of file diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 36423b6c4240a9ae3e7d15c998ce6ec3ebfafc5d..83886762c2893e2e8bbdb8401c7a4f9abafc97a8 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -1,10 +1,14 @@ +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} +{%- endif %} + {%- if messages %} - {%- if system_message or tools %} <|system|> - -{%- if system_message %} {{ system_message }} -{%- endif %} +{%- if tools %} In addition to plain text responses, you can chose to call one or more of the provided functions. Use the following rule to decide when to call a function: @@ -19,13 +23,11 @@ If you decide to call functions: * make sure you pick the right functions that match the user intent -{%- if tools %} {%- for t in tools %} {{- t | tojson(indent=4) }} {{- "\n\n" }} {%- endfor %} {%- endif %}<|end|> - {%- endif %} {%- for message in messages %} {%- if message.role != "system" %} diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 0000000000000000000000000000000000000000..49b0e8d0ee7e655434942233b4b0ab2104d5247e --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "<tools>" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }} + {%- if tool.description is defined %} + {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }} + {%- endif %} + {{- '\n<parameters>' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n<parameter>' }} + {{- '\n<name>' ~ param_name ~ '</name>' }} + {%- if param_fields.type is defined %} + {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n</parameter>' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n</parameters>' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n</function>' }} + {%- endfor %} + {{- "\n</tools>" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '<parameter=' + args_name + '>\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n</parameter>\n' }} + {%- endfor %} + {%- endif %} + {{- '</function>\n</tool_call>' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/mkdocs.yaml b/mkdocs.yaml index 47fe1ebce9712ef72182f040856e741eea398480..507a80c41e8b48641c9dbfd03972a62ed12bdf8d 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -129,15 +129,16 @@ markdown_extensions: - toc: permalink: true # For math rendering - - mdx_math: - enable_dollar_delimiter: true + - pymdownx.arithmatex: + generic: true extra_css: - mkdocs/stylesheets/extra.css extra_javascript: - mkdocs/javascript/run_llm_widget.js - - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML + - mkdocs/javascript/mathjax.js + - https://unpkg.com/mathjax@3.2.2/es5/tex-mml-chtml.js - mkdocs/javascript/edit_and_feedback.js - mkdocs/javascript/slack_and_forum.js diff --git a/pyproject.toml b/pyproject.toml index bf119cf1f1a833ae421dd858afba15a38d35bc56..f221d3b004d0bc96f153f93c1ace526efe150622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,14 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ] -requires-python = ">=3.9,<3.13" +requires-python = ">=3.9,<3.14" dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] diff --git a/requirements/build.txt b/requirements/build.txt index 5275c9549e9a4f892364ab21b8c79082ab2a2dc1..95c119e2a4624896abce29325bbb1d4b0003613d 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -8,3 +8,4 @@ torch==2.5.1 wheel jinja2>=3.1.6 regex +build diff --git a/requirements/common.txt b/requirements/common.txt index 2dfd3f262f8f434b4e317793e2d6e0e4de2b54b0..9e7ccc3abae06354181fc6a1d34a79807a7cc8ff 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,20 +7,21 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.55.0 +transformers >= 4.55.2 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.10 +pydantic >= 2.11.7 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.11, < 0.11 +lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines_core == 0.2.10 +outlines_core == 0.2.10 ; platform_machine != "s390x" +outlines == 0.1.11 ; platform_machine == "s390x" # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 @@ -38,7 +39,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.10.2 # required for compressed-tensors +compressed-tensors == 0.11.0 # required for compressed-tensors depyf==0.19.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 6860275acab6f9f5a831a8da59552524b7917662..a48cb9fde000c06b41598e930463676b3e19127e 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,25 +1,24 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.60.0; python_version == '3.9' and platform_machine != "s390x" # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding +numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x" # Dependencies for CPUs packaging>=24.2 setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 -torch==2.7.0; platform_system == "Darwin" -torch==2.7.0; platform_machine == "ppc64le" -torch==2.6.0; platform_machine == "aarch64" # for arm64 CPUs, torch 2.7.0 has a issue: https://github.com/vllm-project/vllm/issues/17960 +torch==2.8.0; platform_system == "Darwin" +torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" -torchaudio==2.7.0; platform_machine == "ppc64le" +torchaudio==2.8.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" -torchvision==0.22.0; platform_machine == "ppc64le" +torchvision==0.23.0; platform_machine == "ppc64le" datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs diff --git a/requirements/cuda.txt b/requirements/cuda.txt index fb30e493f80b39cf18f3ba3ee199298df39c4134..3f8b8fca3209aaca2471d192b8cb59c2ce540de4 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -6,9 +6,9 @@ numba == 0.61.2; python_version > '3.9' # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.7.1 -torchaudio==2.7.1 +torch==2.8.0 +torchaudio==2.8.0 # These must be updated alongside torch -torchvision==0.22.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31 -xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 \ No newline at end of file +torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 +xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 diff --git a/requirements/docs.txt b/requirements/docs.txt index a24b9c7e924bf2e51f36f8e9621b6e5972da0eca..d1c546398780a196fcad0b0abc75e8b928014b75 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,27 +7,12 @@ mkdocs-awesome-nav mkdocs-glightbox mkdocs-git-revision-date-localized-plugin mkdocs-minify-plugin -python-markdown-math regex ruff # Required for argparse hook only -f https://download.pytorch.org/whl/cpu cachetools -cbor2 -cloudpickle -fastapi msgspec -openai -openai-harmony -partial-json-parser -pillow -psutil -pybase64 pydantic -setproctitle torch -transformers -zmq -uvloop -prometheus-client diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 491fa0625963192477538d96e9d70a653cc4b480..a529bf4504e40ed552bdf454935ce1a563e212c8 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 94201543cd4f30bc8a567ade430dcab73ab565cf..affe562c24f6b2db36523f1c0040f8ce48bc2a35 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,12 +1,12 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.2.4 -torch==2.7.0 -torchvision==0.22.0 -torchaudio==2.7.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.8.0 +torchvision==0.23.0 +torchaudio==2.8.0 -triton==3.2 +triton==3.3.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 608f39df0c57d2cc9b92747ce57c1b843506b2f4..f320e10e70ff27550f46d795675fe13352395995 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -21,7 +21,6 @@ runai-model-streamer-s3==0.11.0 # conch-triton-kernels==1.2.1 # numpy>=1.26.4 numa -python-multipart pytrie setuptools_scm>=8 cmake==3.29 @@ -30,4 +29,6 @@ torch == 2.5.1 triton == 3.0.0 flash_attn == 2.6.1 flash_mla == 1.0.0 -lmslim == 0.3.1 \ No newline at end of file +lightop == 0.5.0 +lmslim == 0.3.1 + diff --git a/requirements/test.in b/requirements/test.in index 6652bfdfe66c97564d08b988e2e2be3450307851..5b1688c76c954c3e04c56bb9c286dc2367252635 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -22,9 +22,9 @@ sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.7.1 -torchaudio==2.7.1 -torchvision==0.22.1 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.8.2 # required for voxtral test @@ -32,9 +32,10 @@ num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +# TODO: Use lm-eval[api]==0.4.10 once released +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.55.0 +transformers==4.55.2 tokenizers==0.21.1 schemathesis>=3.39.15 # Required for openai schema test. # quantization @@ -53,3 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 terratorch==1.1rc2 # required for PrithviMAE test +decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index ff9886a3159766752f89d5f2e568ecd9a4409d91..0b728ebfb007134f4ab08bdb699ff62ec400d3bb 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -408,7 +410,7 @@ lightning-utilities==0.14.3 # torchmetrics llvmlite==0.44.0 # via numba -lm-eval==0.4.8 +lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # via -r requirements/test.in lxml==5.3.0 # via @@ -493,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate @@ -538,42 +541,42 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.3.14 +nvidia-cublas-cu12==12.8.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.57 +nvidia-cuda-cupti-cu12==12.8.90 # via torch -nvidia-cuda-nvrtc-cu12==12.8.61 +nvidia-cuda-nvrtc-cu12==12.8.93 # via torch -nvidia-cuda-runtime-cu12==12.8.57 +nvidia-cuda-runtime-cu12==12.8.90 # via torch -nvidia-cudnn-cu12==9.7.1.26 +nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.41 +nvidia-cufft-cu12==11.3.3.83 # via torch -nvidia-cufile-cu12==1.13.0.11 +nvidia-cufile-cu12==1.13.1.3 # via torch -nvidia-curand-cu12==10.3.9.55 +nvidia-curand-cu12==10.3.9.90 # via torch -nvidia-cusolver-cu12==11.7.2.55 +nvidia-cusolver-cu12==11.7.3.90 # via torch -nvidia-cusparse-cu12==12.5.7.53 +nvidia-cusparse-cu12==12.5.8.93 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.3 +nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.26.2 +nvidia-nccl-cu12==2.27.3 # via torch -nvidia-nvjitlink-cu12==12.8.61 +nvidia-nvjitlink-cu12==12.8.93 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.55 +nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 # via @@ -742,7 +745,7 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.5 +pydantic==2.11.7 # via # -r requirements/test.in # albumentations @@ -1066,7 +1069,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.1+cu128 +torch==2.8.0+cu128 # via # -r requirements/test.in # accelerate @@ -1095,7 +1098,7 @@ torch==2.7.1+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.1+cu128 +torchaudio==2.8.0+cu128 # via # -r requirements/test.in # encodec @@ -1108,7 +1111,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.22.1+cu128 +torchvision==0.23.0+cu128 # via # -r requirements/test.in # lightly @@ -1139,7 +1142,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.55.0 +transformers==4.55.2 # via # -r requirements/test.in # genai-perf @@ -1149,7 +1152,7 @@ transformers==4.55.0 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.3.1 +triton==3.4.0 # via torch tritonclient==2.51.0 # via diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 7bb77c4a99636fa37e977dcd7ac772e761b77a4c..7ea239b48ea2622edc2494533aa3550ac59c9717 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -11,6 +11,7 @@ ray[default] ray[data] setuptools==78.1.0 nixl==0.3.0 +tpu_info==0.4.0 # Install torch_xla --pre diff --git a/setup.py b/setup.py index 09bd4e3f219f24516185f1e249cea425dccea4d3..fe433656252d052f276a326ecc48530eeb1be067 100644 --- a/setup.py +++ b/setup.py @@ -550,8 +550,8 @@ def get_version_add(sha: Optional[str] = None) -> str: new_version_content = f""" try: - __version__ = "0.10.1.1" - __version_tuple__ = (0, 10, 1, 1) + __version__ = "0.10.2.rc1" + __version_tuple__ = (0, 10, 2, rc1) __hcu_version__ = f'0.10.1.1+{version}' from vllm.version import __version__, __version_tuple__, __hcu_version__ @@ -765,16 +765,25 @@ if envs.VLLM_USE_PRECOMPILED: if wheel_location is not None: wheel_url = wheel_location else: + import platform + arch = platform.machine() + if arch == "x86_64": + wheel_tag = "manylinux1_x86_64" + elif arch == "aarch64": + wheel_tag = "manylinux2014_aarch64" + else: + raise ValueError(f"Unsupported architecture: {arch}") base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() - wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" from urllib.request import urlopen try: with urlopen(wheel_url) as resp: if resp.status != 200: - wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = nightly_wheel_url except Exception as e: print(f"[warn] Falling back to nightly wheel: {e}") - wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + wheel_url = nightly_wheel_url patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( wheel_url) @@ -807,7 +816,9 @@ setup( "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.11"], + "flashinfer": ["flashinfer-python==0.2.14.post1"], + # Optional deps for AMD FP4 quantization support + "petit-kernel": ["petit-kernel"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 911f17e569c15524c17e98846e9a038874b39566..9383917d4a7852314d605551b1c85b0586f21be7 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -12,7 +12,6 @@ import pytest import torch from vllm import LLM, envs -from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import HfRunner, VllmRunner @@ -83,11 +82,7 @@ def test_models( "VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - - if backend in ("XFORMERS", - "FLASHINFER") and model == os.path.join(models_path_prefix, "google/gemma-2-2b-it"): + if backend == "XFORMERS" and model == os.path.join(models_path_prefix, "google/gemma-2-2b-it"): pytest.skip( f"{backend} does not support gemma2 with full context length.") @@ -146,6 +141,7 @@ def test_models( ) + # @multi_gpu_test(num_gpus=2) # @pytest.mark.parametrize( # "model, distributed_executor_backend, attention_backend, " @@ -162,8 +158,6 @@ def test_models( # ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), # ("distilbert/distilgpt2", "ray", "", "A100", {}), # ("distilbert/distilgpt2", "mp", "", "A100", {}), -# ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}), -# ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}), # ]) # @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) # def test_models_distributed( diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index a130df8dd6b1024b71c72b549e57b3bacc90fcf5..59b4b02c53e6ec6c99a6b59f3642ad050350c53b 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -176,4 +176,35 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): output3 = llm.generate(prompt, sampling_params) # cmp output - assert output[0].outputs[0].text == output3[0].outputs[0].text \ No newline at end of file + assert output[0].outputs[0].text == output3[0].outputs[0].text + + +@create_new_process_for_each_test() +def test_deep_sleep(): + model = "Qwen/Qwen3-0.6B" + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM(model, enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # Put the engine to deep sleep + llm.sleep(level=2) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + assert used_bytes < 3 * GiB_bytes + + llm.wake_up(tags=["weights"]) + llm.collective_rpc("reload_weights") + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + assert used_bytes < 4 * GiB_bytes + + # now allocate kv cache and cuda graph memory + llm.wake_up(tags=["kv_cache"]) + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..26cae369cdd5dd0a9d6faa18cd750b37385ac104 --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, + SampleRequest) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + output_len: int + + +@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params(num_requests=16, + prefix_len=7, + range_ratio=0.3, + input_len=50, + output_len=20) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + """Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples(dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. + """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a == b + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. + + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return (req.prompt, req.prompt_len, req.expected_output_len, len(items), + item_prefixes) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). + with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt( + hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 + for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index e460d70951786b1f166992e8e64e0a543dfe175a..f5e2d9ddb7528d52cb15c2aa397a3695d46b3404 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel): return x -def test_ignore_torch_compile_decorator(): - assert VLLM_USE_V1 - - # piecewise - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) - - @support_torch_compile - class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + x - attn_output = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, attn_output) - x = attn_output - x = x * 3 - return x - - @ignore_torch_compile - class B(A): - ... - - @support_torch_compile - class C(B): - ... - - with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() - - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - # first run is for compile - mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - # run cudagraph captured sizes - mod_A(torch.randn(2, MLP_SIZE).cuda()) - mod_A(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() - - # B's ignore_torch_compile should override A's support_torch_compile - with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, - ), set_forward_context({}, vllm_config=vllm_config): - mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_B(torch.randn(2, MLP_SIZE).cuda()) - mod_B(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() - - # C's support_torch_compile should override B's ignore_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_C(torch.randn(2, MLP_SIZE).cuda()) - mod_C(torch.randn(1, MLP_SIZE).cuda()) - - @torch.inference_mode -def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): +def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, + cudagraph_runtime_mode: CUDAGraphMode): with set_forward_context({}, vllm_config=vllm_config): - # First run is for compile + # warmup for the model with cudagraph_mode NONE model(inputs) - # Run CUDAGraph captured sizes - model(inputs[:2]) - model(inputs[:1]) - - output = model(inputs[:2]) + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(inputs[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(inputs[:1]) + + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(inputs[:2]) output = output.cpu() return output.cpu() @@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): splitting_ops=["silly.attention"], cudagraph_capture_sizes=[1, 2], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_cudagraph_captured=8, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.NO_COMPILATION, )) + cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # piecewise compile without CUDA graph vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): use_cudagraph=False, splitting_ops=["silly.attention"], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=4, num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 628594d74956a881e41255e758bdca6888c16ceb..6d1d618440f083d5778a0d96616d6288516a1348 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -30,15 +30,15 @@ class TestSetting: "test_setting", [ # basic llama model - # TestSetting( - # model=os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"), - # model_args=["--max-model-len", "2048"], - # pp_size=2, - # tp_size=2, - # attn_backend="FLASHINFER", - # method="generate", - # fullgraph=True, - # ), + TestSetting( + model=os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"), + model_args=["--max-model-len", "2048"], + pp_size=2, + tp_size=2, + attn_backend="FLASH_ATTN", + method="generate", + fullgraph=True, + ), # llama model with quantization TestSetting( model=os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"), diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..51f8ddd566d5646d66027f268de45041ebb3d17e --- /dev/null +++ b/tests/compile/test_decorator.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + cudagraph_runtime_mode: CUDAGraphMode): + with set_forward_context({}, vllm_config=vllm_config): + # warmup for the model with cudagraph_mode NONE + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(torch.randn(2, MLP_SIZE).cuda()) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(torch.randn(1, MLP_SIZE).cuda()) + + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(torch.randn(2, MLP_SIZE).cuda()) + + output = output.cpu() + return output.cpu() + + +def test_ignore_torch_compile_decorator(): + # piecewise + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + @support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + + # B's ignore_torch_compile should override A's support_torch_compile + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + run_model(vllm_config, mod_B, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_C, cudagraph_runtime_mode) + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=True +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. + kv_sharing_fast_prefill) +class B(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x + x + return x + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=False +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. + cache_config.kv_sharing_fast_prefill) +class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mod1(x) + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = self.mod2(x) + return x + + +def test_conditional_compile_enable_if(): + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=True, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile but enable_if fn returns False + # enalbe_if will be True for B, so we expect mod1 and mod2 + # to be compiled + with compilation_counter.expect( + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + # 3 piecewise graphs per instance of B() + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + # Set kv_sharing_fast_prefill=False + # which will cause A to be compiled and B to not be compiled + vllm_config = VllmConfig(cache_config=CacheConfig( + kv_sharing_fast_prefill=False, ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=7, + # 3 attn ops and 4 non-attn ops + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index a2fc6ffeb8b26047d4e6fe9d11f2f1a5e5f7184e..84178344a5f3685cff1895393f8c5846bd0f8ccd 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -53,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): "quantization": "gptq_marlin_24" })) - if is_quant_method_supported("marlin"): - TEST_MODELS.append( - ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { - "quantization": "marlin" - })) - if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { "quantization": "AWQ" diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4c3cf6c2a10cfbdc344df120ceaff2d54e1f11af..dd31e0db1f59f282a10e409b4de61f0179c441ab 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -148,7 +148,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index a6baa97fe699065dd72b7b3c12e04a8875fe2d55..fb9f9dde22799fd55334c916f5b2b0c305adb832 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module): # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, - use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, diff --git a/tests/compile/untest_functionalization.py b/tests/compile/untest_functionalization.py index 368ffa2db4e1718e7e99193f2117f1927e4b1bd4..74eccc9bc4acea619b135d9613b1422438c71f05 100644 --- a/tests/compile/untest_functionalization.py +++ b/tests/compile/untest_functionalization.py @@ -8,11 +8,12 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from .backend import TestBackend diff --git a/tests/compile/untest_fusion.py b/tests/compile/untest_fusion.py index 4a3820e20fd892bd7a71a93904ef48bfdc9d2006..c4229f93464ac93eb2799bdd4ba221d99846176d 100644 --- a/tests/compile/untest_fusion.py +++ b/tests/compile/untest_fusion.py @@ -7,13 +7,15 @@ import torch import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) + FusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) + Fp8LinearOp, maybe_create_device_identity) from vllm.platforms import current_platform from .backend import TestBackend @@ -24,16 +26,14 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestModel(torch.nn.Module): def __init__(self, hidden_size: int, eps: float, static: bool, - cutlass_fp8_enabled: bool, *args, **kwargs): + force_fp8_e4m3fnuz: bool, *args, **kwargs): super().__init__(*args, **kwargs) - self.cutlass_fp8_enabled = cutlass_fp8_enabled + self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + quant_scale = ScaleDesc(torch.float32, static, group_shape) + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -43,7 +43,7 @@ class TestModel(torch.nn.Module): for _ in range(2) ] self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, act_quant_static=static, act_quant_group_shape=group_shape, ) @@ -81,12 +81,11 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) +@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cutlass_fp8_enabled): + force_fp8_e4m3fnuz): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -103,7 +102,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) - model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) + model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) diff --git a/tests/compile/untest_fusion_attn.py b/tests/compile/untest_fusion_attn.py index 70750eb9ac4eecc1109a878a1999ddbc1a5e0cda..dba668cfa16a60850bc75f92f20e166d3da01dc2 100644 --- a/tests/compile/untest_fusion_attn.py +++ b/tests/compile/untest_fusion_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from typing import Optional import pytest @@ -7,13 +8,29 @@ import torch._dynamo from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata) from vllm import LLM, SamplingParams -from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.attention import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + ModelConfig, PassConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import AttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -90,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported(quant_key) for key, layer in compile_config.static_forward_context.items() ] @@ -132,3 +147,309 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # Reset backend to make sure llm2 gets released backend = None + + +class AttentionQuantPatternModel(torch.nn.Module): + """Base model for AttentionQuantPattern fusion.""" + + def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig, **kwargs): + super().__init__() + self.num_qo_heads = num_qo_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + self.attn = Attention( + num_heads=self.num_qo_heads, + head_size=self.head_size, + scale=1.0 / (self.head_size**0.5), + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix="model.layers.0.self_attn.attn", + ) + + self.block_size = 16 + + # Initialize attn MetadataBuilder + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + use_mla=False, + ), + layer_names=[self.attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, + ) + + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = [kv_cache] + + # Build attn metadata + self.attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata) + + return self.attn_metadata + + +class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionFp8StaticQuantPattern fusion.""" + + quant_key = kFp8StaticTensorSym + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.quant_key.scale.static, + act_quant_group_shape=self.quant_key.scale.group_shape) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randn(hidden_size, hidden_size).to( + dtype=FP8_DTYPE, device=self.device).t(), + "wscale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + return self.fp8_linear.apply(input=attn_output, + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"]) + + +class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionNvfp4QuantPattern fusion.""" + + quant_key = kNvfp4Quant + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device), + "wscale_swizzled": + torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device), + "wscale": + torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([0.002], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"]) + return cutlass_scaled_fp4_mm(a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype) + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("batch_size", [7, 256, 533]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_name, model_class", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)]) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), + reason="Only test on SM100(Blackwell)") +def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, + head_size: int, batch_size: int, + dtype: torch.dtype, model_name: str, + model_class: type[AttentionQuantPatternModel], + backend: _Backend, monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + + device = torch.device("cuda:0") + torch.manual_seed(42) + + vllm_config = VllmConfig( + model_config=ModelConfig( + model=model_name, + max_model_len=2048, + ), + scheduler_config=SchedulerConfig(max_num_seqs=1024), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+quant_fp8"], + ), + cache_config=CacheConfig(cache_dtype="fp8")) + + # Create test inputs + q = torch.randn(batch_size, + num_qo_heads * head_size, + dtype=dtype, + device=device) + k = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + v = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + + # Mark first dimension as dynamic for realistic testing + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(k, 0) + torch._dynamo.mark_dynamic(v, 0) + + # Run model directly without compilation and fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with set_current_vllm_config(vllm_config_unfused), set_forward_context( + attn_metadata=None, vllm_config=vllm_config_unfused + ), global_force_attn_backend_context_manager(backend): + model_unfused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused) + model_unfused = model_unfused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata( + batch_size) + + # Run model directly without compilation and fusion + result_unfused = model_unfused(q, k, v) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + enable_attn_fusion=True, enable_noop=True) + with set_current_vllm_config(vllm_config), set_forward_context( + attn_metadata=None, vllm_config=vllm_config + ), global_force_attn_backend_context_manager(backend): + model_fused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes enabled + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw + ) + test_backend = TestBackend(noop_pass, attn_pass) + + # Compile model with fusion enabled + model_compiled = torch.compile(model_fused, + backend=test_backend, + fullgraph=True) + assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v) + + # After the 1st round of the forward pass, output quant scale should be + # loaded into the attn layer's _o_scale_float, the 2nd round should + # reuse the loaded _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + # Check attn fusion support + quant_key = model_class.quant_key + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key) for key, layer in + vllm_config.compilation_config.static_forward_context.items() + ] + if any(attn_fusion_supported): + # Check quantization ops in the graph before and after fusion + test_backend.check_before_ops([QUANT_OPS[quant_key]], + fully_replaced=True) + + # Check attention ops in the graph before and after fusion + attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, + test_backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" + assert len(attn_nodes_pre) == len(attn_nodes_post), \ + "Should have same number of attention nodes before and after fusion" + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + "Attention should not have output_scale before fusion" + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + "Attention should have output_scale after fusion" + + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale before fusion" + if quant_key.dtype == FP8_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale after FP8 fusion" + elif quant_key.dtype == FP4_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ + "Attention should have output_block_scale after FP4 fusion" # noqa: E501 + + # Check that results are closed + torch.testing.assert_close(result_unfused, + result_fused_1, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) diff --git a/tests/compile/untest_silu_mul_quant_fusion.py b/tests/compile/untest_silu_mul_quant_fusion.py index 5351a3cf35ba51f777206238a52e0fb37575d040..fcc2589e421167567b48ec240dfedeffec5b6408 100644 --- a/tests/compile/untest_silu_mul_quant_fusion.py +++ b/tests/compile/untest_silu_mul_quant_fusion.py @@ -4,35 +4,44 @@ import pytest import torch import vllm.envs as envs -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +# yapf conflicts with isort for this block +# yapf: disable +from vllm.compilation.activation_quant_fusion import ( + FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass) +# yapf: enable +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, kFp8StaticTensorSym, kNvfp4Quant) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp) + Fp8LinearOp) from vllm.platforms import current_platform from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, - **kwargs): - super().__init__(*args, **kwargs) +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +class TestSiluMulFp8QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs): + super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR, ) @@ -45,15 +54,56 @@ class TestModel(torch.nn.Module): input_scale=self.wscale) return x2 + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] + + def ops_in_model_after(self): + return [FUSED_OPS[kFp8StaticTensorSym]] + + +class TestSiluMulNvfp4QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w = torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE) + self.wscale = torch.randn(hidden_size, + hidden_size // 16).to(dtype=FP8_DTYPE) + self.wscale2 = torch.rand(1, dtype=torch.float32) + self.scale = torch.rand(1, dtype=torch.float32) -@pytest.mark.parametrize("num_tokens", [256]) -@pytest.mark.parametrize("hidden_size", [64]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) + def forward(self, x): + y = self.silu_and_mul(x) + y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale) + out = cutlass_scaled_fp4_mm(a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.wscale, + alpha=self.scale * self.wscale2, + out_dtype=y.dtype) + return out + + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + + def ops_in_model_after(self): + return [FUSED_OPS[kNvfp4Quant]] + + +@pytest.mark.parametrize("num_tokens", [64]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize( + "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]) +@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, - cutlass_fp8_enabled): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, + force_fp8_e4m3fnuz): + if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz: + pytest.skip("Duplicate tests for NVFP4") + torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -64,7 +114,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, cutlass_fp8_enabled) + model = model_class(hidden_size=hidden_size, + force_fp8_e4m3fnuz=force_fp8_e4m3fnuz) # First dimension dynamic x = torch.rand(num_tokens, hidden_size * 2) @@ -81,17 +132,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, atol=1e-3, rtol=1e-3) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes - - silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default - - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, silu_and_mul_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/conftest.py b/tests/conftest.py index 455479d4666d77446c17e4dc9e51b31851344750..ca58bd2bfcf42e8e5b94cec0dc310fd025283da9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import math import os import tempfile from enum import Enum - -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast import pytest import pytest_html @@ -37,7 +37,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams - +from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect from .utils import models_path_prefix @@ -460,9 +460,16 @@ class HfRunner: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - logits = output.logits.softmax(dim=-1)[0].tolist() + if problem_type == "regression": + logits = output.logits[0].tolist() + elif problem_type == "multi_label_classification": + logits = output.logits.sigmoid()[0].tolist() + else: + logits = output.logits.softmax(dim=-1)[0].tolist() outputs.append(logits) return outputs @@ -600,7 +607,7 @@ class HfRunner: def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: int, + num_logprobs: Optional[int], ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -628,7 +635,7 @@ class HfRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, @@ -675,7 +682,7 @@ class HfRunner: self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: @@ -964,7 +971,7 @@ class VllmRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, @@ -989,11 +996,40 @@ class VllmRunner: videos=videos, **kwargs) + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs(prompts, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities + def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], @@ -1020,15 +1056,17 @@ class VllmRunner: images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs = self.llm.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + outputs = self.llm.beam_search(inputs, + BeamSearchParams(beam_width=beam_width, + max_tokens=max_tokens), + concurrency_limit=concurrency_limit) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] @@ -1086,6 +1124,9 @@ class VllmRunner: return self.llm.llm_engine.collective_rpc(_apply_model) + def get_llm(self) -> LLM: + return self.llm + def __enter__(self): return self diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index f931b425f584bc68302b922c71f9464cde88d3da..a9d7c635af391c818b99a06d73c0b736614893b6 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -34,7 +34,7 @@ BLOCK_SIZE = 16 @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, batch_size, seed, backend, monkeypatch): """ @@ -45,8 +45,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") if backend == "XFORMERS" and current_platform.is_rocm(): pytest.skip("Xformers does not support ROCm/HIP.") @@ -98,7 +96,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, backend, monkeypatch): """ @@ -109,8 +107,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, The results with and without chunked prefill are not the same due to numerical instabilities. """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") if backend == "XFORMERS" and current_platform.is_rocm(): pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e2cb579e22dc4167852af6b52a9811fb229c67c6..8d84cc2d0ffe6dca38a01234ac4a88738e8620f1 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -18,7 +18,8 @@ from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) -from ..utils import init_test_distributed_environment, multi_process_parallel +from ..utils import (init_test_distributed_environment, multi_gpu_test, + multi_process_parallel) @ray.remote(num_gpus=1, max_calls=1) @@ -226,8 +227,7 @@ def send_recv_test_worker( torch.testing.assert_close(test_tensor, recv_tensor) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, @@ -241,8 +241,7 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) @@ -254,8 +253,7 @@ def test_multi_process_pipeline_parallel( multi_process_parallel(monkeypatch, 1, pp_size, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize("test_target", [ diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 0d6af75526a5fc1e5fabd256b3384abc0a5d2051..7b758628f6e6eaad99d8426bfdd9724c16ef32df 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -118,6 +118,8 @@ class PPTestSettings: multi_node_only: bool = False, load_format: Optional[str] = None, ): + vllm_major_versions = ["1"] if runner == "pooling" else ["0"] + return PPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, @@ -126,7 +128,7 @@ class PPTestSettings: chunked_prefill=False), ], distributed_backends=["mp"], - vllm_major_versions=["0"], + vllm_major_versions=vllm_major_versions, runner=runner, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = { EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct"): PPTestSettings.fast(runner="pooling"), - os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2"): PPTestSettings.fast(runner="pooling"), + # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883 + # is fixed + #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"): PPTestSettings.fast( load_format="dummy", runner="pooling" ), @@ -233,6 +237,7 @@ MULTIMODAL_MODELS = { os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"): PPTestSettings.fast(), os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"): PPTestSettings.fast(), os.path.join(models_path_prefix, "AIDC-AI/Ovis2-1B"): PPTestSettings.fast(), + os.path.join(models_path_prefix, "AIDC-AI/Ovis2.5-2B"): PPTestSettings.fast(), os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"): PPTestSettings.fast(), os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"): PPTestSettings.fast(load_format="dummy"), os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat"): PPTestSettings.fast(), diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 4785eeac0087cff57952e97eb3103a490191064c..d44382d14ed69e7488ce8a3536295923c21e6893 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: ]) @pytest.mark.parametrize("ATTN_BACKEND", [ "FLASH_ATTN", - # "FLASHINFER", ]) @create_new_process_for_each_test() def test_pp_cudagraph( diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 49b8eddecb4a9b9f07175c1d6505a06811ec6048..c93b436f384b9aebb82971ce9a8a2c61f7d028e5 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -292,7 +292,7 @@ SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 0000000000000000000000000000000000000000..5a804a389123bb5880eb9c16a6838208d42f84b6 --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." + ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 552cc8db84be1df77c307635fafffe1ce393fc0e..28ca6f9ef87541170e24917f181b2aa3001d156c 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -20,10 +20,9 @@ def text_llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -90,10 +89,9 @@ def vision_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -160,10 +158,9 @@ def thinking_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index 71e76abcb7d2c27dd38e38987e2f777c0d57162e..6c0c9cd0158010adb7710e11804afd8a01b05de8 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -35,10 +27,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -71,3 +62,9 @@ def test_encode_api(llm: LLM): err_msg = "pooling_task must be one of.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompts, use_tqdm=False) + + +def test_score_api(llm: LLM): + err_msg = "Score API is only enabled for num_labels == 1." + with pytest.raises(ValueError, match=err_msg): + llm.score("ping", "pong", use_tqdm=False) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index ba20d7b9548efc311cd77f99347f0cd05305c304..485f04ed6d8491fc8106e08fed3f1173ef393bb6 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -26,10 +26,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index e20819f68ecee5d514fd1a601ec5629db10bd4fd..011a67f0be49f5f37e69c9d66b4324e94dd72cf8 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -6,12 +6,10 @@ import weakref import pytest import os -from vllm import LLM, PoolingParams, PoolingRequestOutput +from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory from ...utils import models_path_prefix -from ...models.utils import check_embeddings_close - MODEL_NAME = os.path.join(models_path_prefix, "intfloat/multilingual-e5-small") PROMPTS = [ @@ -31,14 +29,6 @@ TOKEN_IDS = [ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -50,57 +40,13 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_match(o1: list[PoolingRequestOutput], - o2: list[PoolingRequestOutput]): - check_embeddings_close( - embeddings_0_lst=[o.outputs.data for o in o1], - embeddings_1_lst=[o.outputs.data for o in o2], - name_0="hf", - name_1="vllm", - tol=1e-2, - ) - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) - - v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) - assert_outputs_match(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) - - v2_output = llm.encode( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - pooling_params=pooling_params, - ) - assert_outputs_match(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_pooling_params(llm: LLM): pooling_params = [ diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6d0e66b0d1ab63a081f3bcb889b6563bcee3f8c1..9200f0a4f93dd107f76d432bd4caa3a43385cf5b 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -6,7 +6,7 @@ import os import pytest -from vllm import LLM, RequestOutput, SamplingParams +from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory from ...utils import models_path_prefix @@ -43,50 +43,13 @@ def llm(): gpu_memory_utilization=0.10, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): - assert [o.outputs for o in o1] == [o.outputs for o in o2] - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) - - v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, - sampling_params=sampling_params) - - v2_output = llm.generate( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_sampling_params(llm: LLM): sampling_params = [ diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 361e2d0e1047fc6340ce71fe5ec46a76bc94f06f..2cee3c8d94e362281c503aa08a545a37789c29ba 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward" prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -36,10 +28,9 @@ def llm(): trust_remote_code=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index dd4eae0ccc06e32dba06b0336afa8831ebdf537d..f715dacacb8ff22ddcff5dfc2b2d5518226cc59a 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -14,14 +14,6 @@ from ...models.utils import softmax MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -33,10 +25,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index d3ab1b3d5dfb2721c110742ac444a1550ea6b79e..1d1b8d0e7cbfc4e531725154400d549a5d3a25b1 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" +import dataclasses import importlib import sys import os @@ -10,9 +11,9 @@ import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs from ...utils import models_path_prefix - MODEL_CONFIGS = [ { "model": os.path.join(models_path_prefix, "facebook/opt-125m"), @@ -33,15 +34,16 @@ MODEL_CONFIGS = [ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, - { - "model": "sentence-transformers/all-MiniLM-L12-v2", - "enforce_eager": True, - "gpu_memory_utilization": 0.20, - "max_model_len": 64, - "max_num_batched_tokens": 64, - "max_num_seqs": 64, - "tensor_parallel_size": 1, - }, + # TODO: re-enable once these tests are run with V1 + # { + # "model": "sentence-transformers/all-MiniLM-L12-v2", + # "enforce_eager": True, + # "gpu_memory_utilization": 0.20, + # "max_model_len": 64, + # "max_num_batched_tokens": 64, + # "max_num_seqs": 64, + # "tensor_parallel_size": 1, + # }, ] @@ -111,3 +113,36 @@ def _re_import_modules(): # Error this test if reloading a module failed if reload_exception is not None: raise reload_exception + + +@pytest.mark.skip_global_cleanup +@pytest.mark.usefixtures("cache_models") +def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): + # Set HF to offline mode and ensure we can still construct an LLM + with monkeypatch.context() as m: + try: + m.setenv("HF_HUB_OFFLINE", "1") + m.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + m.setattr( + urllib3.connection.HTTPConnection, + "connect", + disable_connect, + ) + m.setattr( + urllib3.connection.HTTPSConnection, + "connect", + disable_connect, + ) + # Need to re-import huggingface_hub + # and friends to setup offline mode + _re_import_modules() + engine_args = EngineArgs(model="facebook/opt-125m") + LLM(**dataclasses.asdict(engine_args)) + finally: + # Reset the environment after the test + # NB: Assuming tests are run in online mode + _re_import_modules() diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd3510432c06370c39a64793b57819c5..0d0ce0be8c5f86e568f4d6640d06586867efd7f5 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr): return latency, num_output_tokens, transcription.text -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) +async def bound_transcribe(sem, client, tokenizer, audio, reference): # Use semaphore to limit concurrent requests. async with sem: result = await transcribe_audio(client, tokenizer, *audio) @@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) + # Load tokenizer once outside the loop + tokenizer = AutoTokenizer.from_pretrained(model) + # Warmup call as the first `librosa.load` server-side is quite slow. audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] - _ = await bound_transcribe(model, sem, client, (audio, sr), "") + _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "") tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), + sample["text"])) tasks.append(task) return await asyncio.gather(*tasks) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 30078fe90257a434e4f6fed09ce64bdaf84388f4..36c96d76c2e5f5f02060f1248706b0f71b038122 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -226,3 +226,33 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score(server: RemoteOpenAIServer, model_name: str): + # score api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("score"), + json={ + "model": model_name, + "text_1": "ping", + "text_2": "pong", + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank(server: RemoteOpenAIServer, model_name: str): + # rerank api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": "ping", + "documents": ["pong"], + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index b20838956d721c7a8c9a8b0af9999ce80949657e..9a1c0ea13b54f98746c92266622fee7162a1c60e 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -27,6 +27,28 @@ def serve_parser(): return make_arg_parser(parser) +### Test config parsing +def test_config_arg_parsing(serve_parser, cli_config_file): + args = serve_parser.parse_args([]) + assert args.port == 8000 + args = serve_parser.parse_args(['--config', cli_config_file]) + assert args.port == 12312 + args = serve_parser.parse_args([ + '--config', + cli_config_file, + '--port', + '9000', + ]) + assert args.port == 9000 + args = serve_parser.parse_args([ + '--port', + '9000', + '--config', + cli_config_file, + ]) + assert args.port == 9000 + + ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path diff --git a/tests/entrypoints/openai/test_collective_rpc.py b/tests/entrypoints/openai/test_collective_rpc.py new file mode 100644 index 0000000000000000000000000000000000000000..37c0b7a900ac4a5ed6696afa23e37a2e66a14e83 --- /dev/null +++ b/tests/entrypoints/openai/test_collective_rpc.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +class TestWorkerExtension: + + def get_model_name(self) -> str: + """Test non-pydantic return type.""" + return MODEL_NAME + + def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]: + """Echo back both args and kwargs.""" + return dict( + args=list(args), + kwargs=kwargs, + total_items=len(args) + len(kwargs), + ) + + def return_none(self, *args, **kwargs) -> None: + """Test method that does not return anything""" + return + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--worker-extension-cls", + "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", + ] + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={ + "VLLM_SERVER_DEV_MODE": "1", + "CUDA_VISIBLE_DEVICES": "0" + }, + ) as remote_server: + yield remote_server + + +def test_get_model_name(server): + """Test basic response""" + response = requests.post(server.url_for("collective_rpc"), + json={"method": "get_model_name"}) + assert response.status_code == 200 + results = response.json() + assert "results" in results + assert results["results"] == [MODEL_NAME] + + +def test_return_none(server): + """Test return none""" + response = requests.post(server.url_for("collective_rpc"), + json={"method": "return_none"}) + assert response.status_code == 200 + results = response.json() + assert results["results"] == [None] + + +def test_echo_args_kwargs(server): + """Test args, kwargs, and dict response""" + args = ["arg1", "arg2"] + kwargs = {"key1": "value1", "key2": "value2"} + response = requests.post(server.url_for("collective_rpc"), + json={ + "method": "echo_args_kwargs", + "args": args, + "kwargs": kwargs + }) + assert response.status_code == 200 + results = response.json() + result = results["results"][0] + assert result["args"] == args + assert result["kwargs"] == kwargs + assert result["total_items"] == len(args) + len(kwargs) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index a5b081f861074643c3f5fe8609eb530c226e6c10..4ef5d4e8a699a1c69cc4b8668411a4993e40db96 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen3-0.6B" +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": "Optional parameters for weather query", + }, + }, + "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, +] + +messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, +] + @pytest.fixture(scope="module") def server(): # noqa: F811 @@ -27,6 +148,8 @@ def server(): # noqa: F811 "hermes", "--reasoning-parser", "qwen3", + "--gpu-memory-utilization", + "0.4" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -54,129 +177,6 @@ async def client(server): async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: Union[str, dict], enable_thinking: bool): - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - "options": { - "$ref": "#/$defs/WeatherOptions", - "description": - "Optional parameters for weather query", - }, - }, - "required": ["country", "unit"], - "$defs": { - "WeatherOptions": { - "title": "WeatherOptions", - "type": "object", - "additionalProperties": False, - "properties": { - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "default": "celsius", - "description": "Temperature unit", - "title": "Temperature Unit", - }, - "include_forecast": { - "type": "boolean", - "default": False, - "description": - "Whether to include a 24-hour forecast", - "title": "Include Forecast", - }, - "language": { - "type": "string", - "default": "zh-CN", - "description": "Language of the response", - "title": "Language", - "enum": ["zh-CN", "en-US", "ja-JP"], - }, - }, - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_forecast", - "description": "Get the weather forecast for a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["country", "days", "unit"], - }, - }, - }, - ] - - messages = [ - { - "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ - "forecast for the next 5 days, in fahrenheit?", - }, - ] if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, output.extend(chunk.choices[0].delta.tool_calls) assert len(output) > 0 + + +@pytest.fixture(scope="module") +def k2_server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes", + "--reasoning-parser", + "qwen3", + "--gpu-memory-utilization", + "0.4", + ] + # hack to test kimi_k2 tool use tool_id format. + # avoid error in is_deepseek_mla check by setting kv_lora_rank=null + with RemoteOpenAIServer(MODEL_NAME, + args, + override_hf_configs={ + "model_type": 'kimi_k2', + 'kv_lora_rank': None + }) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def k2_client(k2_server): + async with k2_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("tool_choice", ["required"]) +async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, + stream: bool, tool_choice: str): + + if not stream: + # Non-streaming test + chat_completion = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice) + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + assert chat_completion.choices[0].message.tool_calls[ + 0].id == 'functions.get_current_weather:0' + else: + # Streaming test + output_stream = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice, + stream=True) + + output = [] + async for chunk in output_stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + for o in output: + assert o.id is None or o.id == 'functions.get_current_weather:0' diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index f9ec393873ea3698c14e46f8bd77b46e37ed402e..f2a70ecbc1fd38d0aab49144f20d73a39d4f8dea 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -26,14 +26,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = [ diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index f4801172580c63c2fed7d4b4abd203f2377ba38d..818efd825640c26d474be2845c16c86c8b16418b 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -47,6 +47,7 @@ class MockModelConfig: allowed_local_media_path: str = "" encoder_config = None generation_config: str = "auto" + skip_tokenizer_init: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index e071762772906f4cb3ce8b7dd19b036f9dffc8b3..13f3d1cc65673d536631ed45f8e33c90052b6d9b 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import asyncio import subprocess import sys import tempfile @@ -295,6 +295,99 @@ async def test_metrics_exist(server: RemoteOpenAIServer, assert metric in response.text +@pytest.mark.asyncio +async def test_abort_metrics_reset(server: RemoteOpenAIServer, + client: openai.AsyncClient, use_v1: bool): + + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect no running requests or kvcache usage + assert running_requests == 0 + assert waiting_requests == 0 + assert kv_cache_usage == 0.0 + + # Start some long-running requests that we can abort + tasks = [] + for _ in range(3): + task = asyncio.create_task( + client.completions.create( + model=MODEL_NAME, + prompt=_TOKENIZED_PROMPT, + max_tokens=100, # Long generation to give time to abort + temperature=0.0)) + tasks.append(task) + + # Wait a bit for requests to start processing + await asyncio.sleep(0.5) + + # Check that we have running requests + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect running requests and kvcache usage + assert running_requests > 0 + assert kv_cache_usage > 0 + + # Cancel all tasks to abort the requests + for task in tasks: + task.cancel() + + # Wait for cancellations to be processed + await asyncio.sleep(1.0) + + # Check that metrics have reset to zero + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + running_requests_after, waiting_requests_after, kv_cache_usage_after = ( + _get_running_metrics_from_api(server)) + + assert running_requests_after == 0,\ + (f"Expected 0 running requests after abort, got " + f"{running_requests_after}") + assert waiting_requests_after == 0,\ + (f"Expected 0 waiting requests after abort, got " + f"{waiting_requests_after}") + assert kv_cache_usage_after == 0,\ + (f"Expected 0% KV cache usage after abort, got " + f"{kv_cache_usage_after}") + + +def _get_running_metrics_from_api(server: RemoteOpenAIServer): + """Return (running_count, waiting_count, kv_cache_usage)""" + + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + running_requests, waiting_requests, kv_cache_usage = None, None, None + + for family in text_string_to_metric_families(response.text): + if family.name == "vllm:num_requests_running": + for sample in family.samples: + if sample.name == "vllm:num_requests_running": + running_requests = sample.value + break + elif family.name == "vllm:num_requests_waiting": + for sample in family.samples: + if sample.name == "vllm:num_requests_waiting": + waiting_requests = sample.value + break + elif family.name == "vllm:gpu_cache_usage_perc": + for sample in family.samples: + if sample.name == "vllm:gpu_cache_usage_perc": + kv_cache_usage = sample.value + break + + assert running_requests is not None + assert waiting_requests is not None + assert kv_cache_usage is not None + + return running_requests, waiting_requests, kv_cache_usage + + def test_metrics_exist_run_batch(use_v1: bool): input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501 diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 246bd014aa69011faedd9389c256fd00d6ee0367..11ed1c4a9ee4b9938e38908e6dfe8a5184ba4872 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -74,31 +74,44 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \ http://localhost:8000/v1/chat/completions """ # noqa: E501 - if (hasattr(case, "body") and isinstance(case.body, dict) - and "messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): - - for message in case.body["messages"]: - if not isinstance(message, dict): - continue - - # Check for invalid file type in tokenize endpoint - if op.method.lower() == "post" and op.path == "/tokenize": - content = message.get("content", []) - if (isinstance(content, list) and len(content) > 0 and any( - item.get("type") == "file" for item in content)): - return False - - # Check for invalid tool_calls with non-function types - tool_calls = message.get("tool_calls", []) - if isinstance(tool_calls, list): - for tool_call in tool_calls: - if isinstance(tool_call, dict): - if tool_call.get("type") != "function": - return False - if "custom" in tool_call: - return False + if hasattr(case, "body") and isinstance(case.body, dict): + if ("messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0): + + for message in case.body["messages"]: + if not isinstance(message, dict): + continue + + # Check for invalid file type in tokenize endpoint + if op.method.lower() == "post" and op.path == "/tokenize": + content = message.get("content", []) + if (isinstance(content, list) and len(content) > 0 + and any( + item.get("type") == "file" + for item in content)): + return False + + # Check for invalid tool_calls with non-function types + tool_calls = message.get("tool_calls", []) + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if isinstance(tool_call, dict): + if tool_call.get("type") != "function": + return False + if "custom" in tool_call: + return False + + # Sometimes guided_grammar is generated to be empty + # Causing a server error in EBNF grammar parsing + # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 + guided_grammar = case.body.get("guided_grammar") + + if guided_grammar == '': + # Allow None (will be handled as no grammar) + # But skip empty strings + return False + return True return strategy.filter(no_invalid_types) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 73364294cbcdc130a20ffe4bf9b375da0dcf4626..ce4d6c5f5d337725b3629caa2a6e941e9898f572 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 1ca52599c519d574e82898efe61dd58e82665e4e..72d468db08f655b0b971e1c79aff41bfe99c6b2d 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -11,18 +11,25 @@ from openai import BadRequestError, NotFoundError, OpenAI from ...utils import RemoteOpenAIServer -pytest.skip(allow_module_level=True, reason="gpt-oss can't run on CI yet.") - MODEL_NAME = "openai/gpt-oss-20b" -DTYPE = "bfloat16" @pytest.fixture(scope="module") -def server(): +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def server(monkeypatch_module: pytest.MonkeyPatch): args = ["--enforce-eager", "--tool-server", "demo"] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture @@ -269,10 +276,11 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_streaming(client: OpenAI, model_name: str): + # TODO: Add back when web search and code interpreter are available in CI prompts = [ "tell me a story about a cat in 20 words", - "What is 13 * 24? Use python to calculate the result.", - "When did Jensen found NVIDIA? Search it and answer the year only.", + # "What is 13 * 24? Use python to calculate the result.", + # "When did Jensen found NVIDIA? Search it and answer the year only.", ] for prompt in prompts: @@ -281,15 +289,15 @@ async def test_streaming(client: OpenAI, model_name: str): input=prompt, reasoning={"effort": "low"}, tools=[ - { - "type": "web_search_preview" - }, - { - "type": "code_interpreter", - "container": { - "type": "auto" - } - }, + # { + # "type": "web_search_preview" + # }, + # { + # "type": "code_interpreter", + # "container": { + # "type": "auto" + # } + # }, ], stream=True, ) @@ -317,6 +325,7 @@ async def test_streaming(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Web search tool is not available in CI yet.") async def test_web_search(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, @@ -331,6 +340,7 @@ async def test_web_search(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") async def test_code_interpreter(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, @@ -436,6 +446,7 @@ async def test_function_calling(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.flaky(reruns=5) async def test_function_calling_multi_turn(client: OpenAI, model_name: str): tools = [ { diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..6addcb41c4098e17d72a208823f94228b697885d --- /dev/null +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enforce-eager", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_basic_completion_with_emoji(server): + """Test basic completion with emoji to verify token_ids field.""" + async with server.get_async_client() as client: + # Test with return_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": True}, + ) + + # Check the raw response to see the structure + completion_dict = completion.model_dump() + + # Verify prompt_token_ids field is present in the completion response + assert "prompt_token_ids" in completion_dict["choices"][0] + assert isinstance(completion.choices[0].prompt_token_ids, list) + + # Check against the expected prompt token IDs + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + encoded_tokens = tokenizer.encode( + "Complete this sentence with emojis: I love coding 🚀") + # Check that encoded_tokens is a subsequence of prompt_token_ids + assert any(completion.choices[0].prompt_token_ids[i:i + + len(encoded_tokens)] + == encoded_tokens for i in range( + len(completion.choices[0].prompt_token_ids) - + len(encoded_tokens) + 1)) + + # Verify token_ids field is present in the choice + assert completion.choices[0].token_ids is not None + assert isinstance(completion.choices[0].token_ids, list) + assert len(completion.choices[0].token_ids) > 0 + + # Verify decoding works correctly + decoded_text = tokenizer.decode(completion.choices[0].token_ids) + # The decoded text should contain a <|im_end|> at the end + assert decoded_text.startswith(completion.choices[0].text) + + # Test without return_token_ids (should be None) + completion_without = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": False}, + ) + + completion_without_dict = completion_without.model_dump() + assert completion_without_dict["choices"][0].get("token_ids") is None + assert completion_without_dict.get("prompt_token_ids") is None + + +@pytest.mark.asyncio +async def test_chat_completion_with_tool_use(server): + """Test chat completion with tool use (get_weather function).""" + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": + "string", + "description": + "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature", + }, + }, + "required": ["location"], + }, + }, + }] + + async with server.get_async_client() as client: + # Test with return_token_ids enabled + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" + }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids field is present in choices + assert response.choices[0].token_ids is not None + assert isinstance(response.choices[0].token_ids, list) + + # Verify prompt_token_ids field is present + assert response.prompt_token_ids is not None + assert isinstance(response.prompt_token_ids, list) + + # Verify the prompt texts and response texts + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + prompt_text = tokenizer.decode(response.prompt_token_ids) + assert prompt_text.startswith( + "<|im_start|>system\nYou are a helpful assistant.") + assert prompt_text.endswith( + "What's the weather like in Paris?<|im_end|>\n" + "<|im_start|>assistant\n") + + response_text = tokenizer.decode(response.choices[0].token_ids) + assert response_text.startswith('<tool_call>\n{"name": "get_weather"') + assert response_text.endswith("</tool_call><|im_end|>") + + # If tool call was made, verify the response structure + if response.choices[0].message.tool_calls: + assert len(response.choices[0].message.tool_calls) > 0 + tool_call = response.choices[0].message.tool_calls[0] + assert tool_call.function.name == "get_weather" + + # Test without return_token_ids + response_without = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" + }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": False}, + ) + + assert response_without.choices[0].token_ids is None + assert response_without.prompt_token_ids is None + + +@pytest.mark.asyncio +async def test_comparison_with_prompt_logprobs_and_logprobs(server): + """ + Test that token_ids align with prompt_logprobs and + logprobs when return_tokens_as_token_ids is enabled. + """ + async with server.get_async_client() as client: + # Test with both return_token_ids and return_tokens_as_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Hello, world! How are you today?", + max_tokens=20, + temperature=0, + echo=True, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True, + "prompt_logprobs": 1 + }, + ) + + # Verify all fields are present + assert completion.choices[0].token_ids is not None + assert completion.choices[0].prompt_token_ids is not None + assert completion.choices[0].prompt_logprobs is not None + assert completion.choices[0].logprobs is not None + + # Extract token IDs from logprobs + # (when return_tokens_as_token_ids is True) + logprobs_token_ids = [] + for token_str in completion.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + logprobs_token_ids.append(token_id) + + # When echo=True, the logprobs include both prompt and response tokens + # The token_ids field should match the the suffix of response portion + # The prompt_token_ids should match the prompt portion + assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) + response_token_ids_length = len(completion.choices[0].token_ids) + assert logprobs_token_ids[-response_token_ids_length:] == \ + completion.choices[0].token_ids + + # Verify tokenizer consistency + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match or close to original prompt + assert "Hello, world" in prompt_text + + # Decode response tokens + if completion.choices[0].token_ids: + response_text = tokenizer.decode(completion.choices[0].token_ids) + assert completion.choices[0].text.endswith(response_text) + + # Test streaming mode + stream = await client.completions.create( + model=MODEL_NAME, + prompt="Tell me a short fact about Python:", + max_tokens=30, + temperature=0, + stream=True, + echo=False, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True + }, + ) + + # Collect streamed tokens + streamed_prompt_token_ids = [] + streamed_token_ids = [] + streamed_logprob_token_ids = [] + first_chunk = True + async for chunk in stream: + for token_str in chunk.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + streamed_logprob_token_ids.append(token_id) + if first_chunk: + streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids + first_chunk = False + streamed_token_ids += chunk.choices[0].token_ids + + # Verify we collected some tokens and first chunk had prompt_token_ids + assert len(streamed_prompt_token_ids) > 0 + assert streamed_token_ids == streamed_logprob_token_ids + + +@pytest.mark.asyncio +async def test_chat_completion_with_emoji_and_token_ids(server): + """Test chat completion with emojis to verify token_ids handling.""" + chat_messages = [ + { + "role": "system", + "content": "You like to use emojis in your responses." + }, + { + "role": "user", + "content": "Repeat after me: I love cats 🐱" + }, + ] + async with server.get_async_client() as client: + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids are present + response_dict = response.model_dump() + assert response.choices[0].token_ids is not None + assert "prompt_token_ids" in response_dict + + # Verify the response contains the expected fields + assert response.choices[0].message.content is not None + + # Decode token_ids and verify consistency + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + decoded_prompt = tokenizer.decode(response.prompt_token_ids) + assert decoded_prompt.startswith( + "<|im_start|>system\nYou like to use emojis in your responses.") + assert decoded_prompt.endswith( + "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n") + + decoded_response = tokenizer.decode(response.choices[0].token_ids) + # The content should match the response text + # except the ending <|im_end|> + assert decoded_response == response.choices[ + 0].message.content + "<|im_end|>" + + # Test with streaming + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + stream=True, + extra_body={"return_token_ids": True}, + ) + + collected_content = "" + collected_token_ids = [] + first_chunk = True + + async for chunk in stream: + if first_chunk: + assert chunk.prompt_token_ids is not None + assert isinstance(chunk.prompt_token_ids, list) + # Check the prompt_token_ids match the initial prompt + decoded_prompt_stream = tokenizer.decode( + chunk.prompt_token_ids) + assert decoded_prompt_stream == decoded_prompt + first_chunk = False + else: + chunk_dump = chunk.model_dump() + assert "prompt_token_ids" not in chunk_dump, \ + "Subsequent chunks should not have prompt_token_ids" + + if chunk.choices: + if chunk.choices[0].delta.content: + collected_content += chunk.choices[0].delta.content + # token_ids may not present in all chunks + choice_dump = chunk.choices[0].model_dump() + if "token_ids" in choice_dump: + collected_token_ids.extend(chunk.choices[0].token_ids) + + # Verify we got response and token_ids + assert len(collected_content) > 0 + assert len(collected_token_ids) > 0 + + # Verify token_ids decode properly + decoded_response = tokenizer.decode(collected_token_ids) + assert decoded_response == collected_content + "<|im_end|>" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index f698b492f3be6e12c1b15d200ae02e3474593513..8695a56eb93f4987a60eaa5e1d1a87b18b3be438 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -14,15 +14,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse from ...utils import RemoteOpenAIServer, models_path_prefix - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - MODELS = [ { "name": os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3"), diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index cecbc7523ee525fb02b1c840df4fb663658307b4..86e077c15bec68e774f7cb8d6ab014dc27959714 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -284,9 +284,11 @@ async def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 +@pytest.mark.parametrize("model_type", ["gpt_oss", "any"]) @pytest.mark.asyncio -async def test_serving_chat_did_set_correct_cache_salt(): +async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() + mock_model_config.hf_config.model_type = model_type mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py new file mode 100644 index 0000000000000000000000000000000000000000..ed003939c44bebed2e8ddc7f099b1c94a7905c9f --- /dev/null +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import tempfile + +import pytest + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf) +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b") + + +@pytest.fixture(scope="module") +def server(): + global MODEL_PATH + MODEL_PATH = download_weights_from_hf( + MODEL_NAME, + allow_patterns=["*"], + cache_dir=MODEL_PATH, + ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"]) + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--skip-tokenizer-init", + "--load-format", + "dummy", + ] + with RemoteOpenAIServer(MODEL_PATH, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_token_in_token_out_and_logprobs(server): + """ + Test token-in-token-out and token_ids align with prompt_logprobs + & logprobs when return_tokens_as_token_ids is enabled. + """ + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + text = "Hello, world! How are you today?" + token_ids = tokenizer.encode(text) + async with server.get_async_client() as client: + # Test with both return_token_ids and return_tokens_as_token_ids enabled + completion = await client.completions.create( + model=MODEL_PATH, + prompt=token_ids, + max_tokens=20, + temperature=0, + echo=True, + extra_body={ + "return_token_ids": True, + }, + ) + + # Verify all fields are present + assert (completion.choices[0].token_ids is not None + and 0 < len(completion.choices[0].token_ids) <= 20) + assert completion.choices[0].prompt_token_ids is not None + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match or close to original prompt + assert prompt_text == text diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 93239f41a4aeb5c32b9a6751a68c3893048bbfb3..6009d9aeec935b59ec647f93dbb2f8d9674f0fc0 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -69,8 +69,11 @@ async def test_basic_audio(mary_had_lamb, model_name): language="en", response_format="text", temperature=0.0) - out = json.loads(transcription)['text'] - assert "Mary had a little lamb," in out + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + assert "Mary had a little lamb," in out_text + assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio @@ -116,9 +119,12 @@ async def test_long_audio_request(mary_had_lamb, client): language="en", response_format="text", temperature=0.0) - out = json.loads(transcription)['text'] - counts = out.count("Mary had a little lamb") + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + counts = out_text.count("Mary had a little lamb") assert counts == 10, counts + assert out_usage["seconds"] == 161, out_usage["seconds"] @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index 79b6ce059ce49b820ae9116f438ed5ef58e0dab0..121c0413e1af78b7c69a9c50e458db814fa2471a 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -64,6 +64,28 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert response["usage"]["prompt_tokens"] == truncation_size +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + + assert error_details["type"] == "BadRequestError" + assert "This model's maximum context length is" in error_details["message"] + assert "tokens in the input for embedding generation" in error_details[ + "message"] + assert "Please reduce the length of the input" in error_details["message"] + + @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 @@ -74,18 +96,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): } with pytest.raises(openai.BadRequestError) as err: - err = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) - - assert str(err) == f"""openai.BadRequestError: - Error code: 400 - {{'object': 'error', - 'message': 'truncate_prompt_tokens value - ({truncation_size}) - is greater than max_model_len ({max_model_len}). - Please, select a smaller truncation size.', - 'type': 'BadRequestError', - 'param': None, 'code': 400}}""" + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + assert error_details["type"] == "BadRequestError" + expected_message = ("truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + assert error_details["message"] == expected_message @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 1cb5901fc95134f60b905ce1abb193c4d51d8ca7..38a4e54c4c52df6b08cca622d850bc629e920831 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -7,8 +7,6 @@ import openai import pytest import os import pytest_asyncio -import requests -from PIL import Image from transformers import AutoProcessor from vllm.multimodal.utils import encode_image_base64, fetch_image @@ -97,7 +95,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): "role": "user", "content": f"{placeholder}{content}", }] - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] prompt = processor.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 987eaf47ce440d5b21aecca3db5e4d2b4ef910ff..d93ae61656535c0623ea3d1bcdffe485487c26a4 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -6,7 +6,6 @@ import json import pytest import requests -from PIL import Image from transformers import AutoProcessor from vllm.entrypoints.openai.protocol import EmbeddingResponse @@ -72,7 +71,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md new file mode 100644 index 0000000000000000000000000000000000000000..58572c3a6fbc1b728729f299f5cbf7154ae1213c --- /dev/null +++ b/tests/evals/gsm8k/README.md @@ -0,0 +1,35 @@ +# GSM8K Accuracy Evaluation + +This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and vLLM server for better performance and control. + +## Usage + +### Run tests with pytest (like buildkite) + +```bash +pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 +``` + +### Run standalone evaluation script + +```bash +# Start vLLM server first +vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 + +# Run evaluation +python tests/gsm8k/gsm8k_eval.py --port 8000 +``` + +## Configuration Format + +Model configs in `configs/` directory use this YAML format: + +```yaml +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +accuracy_threshold: 0.54 # Minimum expected accuracy +num_questions: 1319 # Number of questions (default: full test set) +num_fewshot: 5 # Few-shot examples from train set +max_model_len: 4096 # Model context length +``` diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fec1fe5bcdfd0b2603420e9ac87122fc32bda1a --- /dev/null +++ b/tests/evals/gsm8k/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caa0448f23d48a0544ea5ceac008850bb555e92d --- /dev/null +++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml @@ -0,0 +1,5 @@ +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" +accuracy_threshold: 0.74 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..615aa69a2d2b6aed88a3e328db644de627a07681 --- /dev/null +++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml @@ -0,0 +1,5 @@ +model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8" +accuracy_threshold: 0.31 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5dbceeeb2b452b77b40e00cef02a237ffd66040 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -0,0 +1,5 @@ +model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" +accuracy_threshold: 0.45 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5319ada30f645020a462bfd72e4c485ab823be3f --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,5 @@ +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +accuracy_threshold: 0.60 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c39fb979d98ac0b1d7a3f94b7bea0c8cfc29cdd4 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-0.6B-FP8" +accuracy_threshold: 0.375 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 \ No newline at end of file diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt new file mode 100644 index 0000000000000000000000000000000000000000..afd1065b9191b75df79f5a1b10a00c8eb1d25c40 --- /dev/null +++ b/tests/evals/gsm8k/configs/models-small.txt @@ -0,0 +1,5 @@ +Qwen3-0.6B-FP8.yaml +Llama-3.2-1B-Instruct-INT8-CT.yaml +Llama-3-8B-Instruct-nonuniform-CT.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-CT.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..d96b0a66ede2b18313dd6036de7ee5bf4a231313 --- /dev/null +++ b/tests/evals/gsm8k/conftest.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + + +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption("--config-list-file", + default="configs/models-small.txt", + help="File containing list of config files to test") + parser.addoption("--tp-size", + default=1, + type=int, + help="Tensor parallel size") + + +def pytest_generate_tests(metafunc): + """Generate test parameters from config files.""" + if "config_filename" in metafunc.fixturenames: + config_list_file = metafunc.config.getoption("--config-list-file") + tp_size = metafunc.config.getoption("--tp-size") + + # Handle both relative and absolute paths + config_list_path = Path(config_list_file) + if not config_list_path.is_absolute(): + # If relative, try relative to test directory first + test_dir_path = Path(__file__).parent / config_list_file + if test_dir_path.exists(): + config_list_path = test_dir_path + else: + # Try relative to current working directory + config_list_path = Path.cwd() / config_list_file + + print(f"Looking for config list at: {config_list_path}") + + config_files = [] + if config_list_path.exists(): + # Determine config directory (same directory as the list file) + config_dir = config_list_path.parent + + with open(config_list_path) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + config_path = config_dir / line + print(f"Checking config file: {config_path}") + if config_path.exists(): + config_files.append(config_path) + print(f" ✓ Found: {config_path}") + else: + print(f" ✗ Missing: {config_path}") + else: + print(f"Config list file not found: {config_list_path}") + + # Generate test parameters + if config_files: + metafunc.parametrize(["config_filename", "tp_size"], + [(config_file, int(tp_size)) + for config_file in config_files], + ids=[ + f"{config_file.stem}-tp{tp_size}" + for config_file in config_files + ]) + else: + print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0ce25f75dd42dc66597a3b0b684b350844a3ac --- /dev/null +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Isolated GSM8K evaluation script for vLLM serve endpoint. +""" + +import argparse +import ast +import asyncio +import json +import os +import time +from collections.abc import Generator +from typing import Optional, Union + +import aiohttp +import numpy as np +import regex as re +import requests +from tqdm.asyncio import tqdm + +INVALID = -9999999 + + +def download_and_cache_file(url: str, filename: Optional[str] = None) -> str: + """Download and cache a file from a URL.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + f.write(chunk) + + return filename + + +def load_gsm8k_data() -> tuple[list[dict], list[dict]]: + """Load GSM8K train and test data""" + train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl" + test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + + train_file = download_and_cache_file(train_url) + test_file = download_and_cache_file(test_url) + + train_data = list(read_jsonl(train_file)) + test_data = list(read_jsonl(test_file)) + + return train_data, test_data + + +def read_jsonl(filename: str) -> Generator[dict, None, None]: + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if not line.startswith("#"): + yield json.loads(line) + + +def get_answer_value(answer_str: str) -> int: + """Extract the numerical answer from the response.""" + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +async def call_vllm_api(session: aiohttp.ClientSession, + prompt: str, + temperature: float, + max_tokens: int, + stop: Optional[list[str]] = None, + url: Optional[str] = None, + seed: Optional[int] = None) -> str: + """Call vLLM's OpenAI-compatible completions endpoint.""" + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "stop": stop, + } + if seed is not None: + data["seed"] = seed + + try: + async with session.post(f"{url}/v1/completions", + json=data) as response: + response.raise_for_status() + result = await response.json() + return result["choices"][0]["text"] + except Exception as e: + print(f"Error calling vLLM API: {e}") + return "" + + +def evaluate_gsm8k(num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + host: str = "http://127.0.0.1", + port: int = 8000, + temperature: float = 0.0, + seed: Optional[int] = 42) -> dict[str, Union[float, int]]: + """ + Evaluate GSM8K accuracy using vLLM serve endpoint. + + Returns dict with accuracy, invalid_rate, latency, etc. + """ + base_url = f"{host}:{port}" + + # Load GSM8K train and test data + train_data, test_data = load_gsm8k_data() + + # Limit to available test questions + num_questions = min(num_questions, len(test_data)) + + # Build few-shot examples from train split (like lm-eval does) + few_shot_examples = "" + for i in range(num_shots): + few_shot_examples += (f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n") + + # Prepare test questions and labels from test split + questions = [] + labels = [] + for i in range(num_questions): + questions.append(f"Question: {test_data[i]['question']}\nAnswer:") + labels.append(get_answer_value(test_data[i]["answer"])) + + assert all(label != INVALID for label in labels), "Some labels are invalid" + + # Run evaluation + async def run_async_evaluation(): + states: list[str] = [""] * num_questions + + async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + prompt = few_shot_examples + questions[i] + answer = await call_vllm_api( + session=session, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + stop=["Question", "Assistant:", "<|separator|>"], + url=base_url, + seed=seed, + ) + states[i] = answer + return answer + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=600)) as session: + tasks = [get_answer(session, i) for i in range(num_questions)] + await tqdm.gather(*tasks, desc="Evaluating") + + return states + + print(f"Running GSM8K evaluation: {num_questions} questions, " + f"{num_shots}-shot") + + tic = time.perf_counter() + states = asyncio.run(run_async_evaluation()) + latency = time.perf_counter() - tic + + # Compute metrics + preds = [get_answer_value(state) for state in states] + accuracy = np.mean(np.array(preds) == np.array(labels)) + invalid_rate = np.mean(np.array(preds) == INVALID) + + result = { + "accuracy": accuracy, + "invalid_rate": invalid_rate, + "latency": latency, + "questions_per_second": num_questions / latency, + "num_questions": num_questions, + "num_shots": num_shots, + "max_tokens": max_tokens, + "timestamp": time.time(), + } + + return result + + +def main() -> None: + parser = argparse.ArgumentParser( + description="GSM8K evaluation for vLLM serve") + parser.add_argument("--num-shots", + type=int, + default=5, + help="Number of few-shot examples") + parser.add_argument("--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate") + parser.add_argument("--max-tokens", + type=int, + default=256, + help="Max tokens for generation") + parser.add_argument("--host", + type=str, + default="http://127.0.0.1", + help="Host URL") + parser.add_argument("--port", type=int, default=8000, help="Port number") + parser.add_argument("--temperature", + type=float, + default=0.0, + help="Temperature for generation") + parser.add_argument("--seed", + type=int, + default=42, + help="Random seed for reproducibility") + parser.add_argument("--save-results", + type=str, + help="Save results to JSON file") + + args = parser.parse_args() + + result = evaluate_gsm8k( + num_questions=args.num_questions, + num_shots=args.num_shots, + max_tokens=args.max_tokens, + host=args.host, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + # Print results to terminal + print("\nResults:") + print(f"Accuracy: {result['accuracy']:.3f}") + print(f"Invalid responses: {result['invalid_rate']:.3f}") + print(f"Total latency: {result['latency']:.3f} s") + print(f"Questions per second: {result['questions_per_second']:.3f}") + + # Optional file saving + if args.save_results: + with open(args.save_results, "w") as f: + json.dump(result, f, indent=2) + print(f"Results saved to {args.save_results}") + + +if __name__ == "__main__": + main() diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..a12dd49dbea6d7b0f345500b374e483e9060b2d3 --- /dev/null +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GSM8K evaluation using vLLM server and isolated GSM8K script. +Replacement for lm-eval-harness with better performance and control. + +Usage: +pytest -s -v test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 +""" + +import yaml + +from tests.utils import RemoteOpenAIServer + +from .gsm8k_eval import evaluate_gsm8k + +RTOL = 0.08 # Relative tolerance for accuracy comparison + + +def launch_gsm8k_eval(eval_config, server_url, tp_size): + """Launch GSM8K evaluation using our isolated script.""" + # Extract host and port from server URL + if "://" in server_url: + server_url = server_url.split("://")[1] + + host_port = server_url.split("/")[0] # Remove path if present + if ":" in host_port: + host, port = host_port.split(":") + port = int(port) + else: + host = host_port + port = 8000 + + # Add http:// prefix if not present + if not host.startswith("http"): + host = f"http://{host}" + + # Run GSM8K evaluation + results = evaluate_gsm8k( + num_questions=eval_config["num_questions"], + num_shots=eval_config["num_fewshot"], + host=host, + port=port, + ) + + return results + + +def test_gsm8k_correctness_param(config_filename, tp_size): + """Test GSM8K correctness for a given model configuration.""" + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) + + # Server arguments + server_args = [ + "--max-model-len", + str(eval_config.get("max_model_len", 4096)), + "--enforce-eager", + "--trust-remote-code", + "--tensor-parallel-size", + str(tp_size), + ] + + # Launch server and run evaluation + with RemoteOpenAIServer(eval_config["model_name"], + server_args, + max_wait_seconds=480) as remote_server: + server_url = remote_server.url_for("v1") + + results = launch_gsm8k_eval(eval_config, server_url, tp_size) + + # Check accuracy against threshold + measured_accuracy = results["accuracy"] + expected_accuracy = eval_config["accuracy_threshold"] + + print(f"GSM8K Results for {eval_config['model_name']}:") + print(f" Accuracy: {measured_accuracy:.3f}") + print(f" Expected: {expected_accuracy:.3f}") + print(f" Questions: {results['num_questions']}") + print(f" Invalid rate: {results['invalid_rate']:.3f}") + print(f" Latency: {results['latency']:.1f}s") + print(f" QPS: {results['questions_per_second']:.1f}") + + # Verify accuracy is within tolerance + assert measured_accuracy >= expected_accuracy - RTOL, ( + f"Accuracy too low: {measured_accuracy:.3f} < " + f"{expected_accuracy:.3f} - {RTOL:.3f}") + + print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index b86f6006590304a32bd621efbc1eb73bc63f84b6..6ea57e1d9a089c091679a237f3d2f9c97a5d47a2 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -705,6 +705,94 @@ def test_swap_blocks_mla( f"{dst} in dst_cache.") +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, + block_size, num_blocks, + max_seq_len, batch_size, dtype, + kv_cache_dtype, device): + entry_size = kv_lora_rank + qk_rope_head_dim + scale = torch.tensor(0.1, dtype=torch.float32, device=device) + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + block_data = src_cache[blocks[i]] + if kv_cache_dtype == "fp8": + dequantized_block = torch.empty_like(block_data, dtype=dtype) + ops.convert_fp8(dequantized_block, block_data, scale.item()) + gathered_rows.append(dequantized_block) + else: + gathered_rows.append(block_data) + remaining = s - (tot - 1) * block_size + last_block_data = src_cache[blocks[-1], :remaining, :] + if kv_cache_dtype == "fp8": + dequantized_last_block = torch.empty_like(last_block_data, + dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, + scale.item()) + gathered_rows.append(dequantized_last_block) + else: + gathered_rows.append(last_block_data) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, kv_cache_dtype, + scale, None) + torch.testing.assert_close(dst, expected) + + @pytest.mark.parametrize("kv_lora_rank", [512]) @pytest.mark.parametrize("qk_rope_head_dim", [64]) @pytest.mark.parametrize("block_size", [16]) @@ -716,9 +804,9 @@ def test_swap_blocks_mla( ["auto"]) # You can also test "fp8" if needed. @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device) @@ -768,12 +856,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, expected = torch.cat(expected_batches, dim=0) opcheck( - torch.ops._C_cache_ops.gather_cache, + torch.ops._C_cache_ops.cp_gather_cache, (src_cache, dst, block_table, cu_seq_lens, batch_size, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) torch.testing.assert_close(dst, expected) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 4b84e6a00eceb9f3fca997cece6fad036b0b7ae8..8d0a11d8eb8ab8bb82478e6cf95a606df9efed03 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,28 +6,19 @@ import flashinfer import pytest import torch +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm.platforms import current_platform +from vllm.utils import round_up if not current_platform.is_device_capability(100): pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -MAX_Q_LEN = 1024 -MAX_KV_LEN = 4096 -BATCH_SIZES = [4, 12] -NUM_HEADS = [(16, 16), (40, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -KV_LAYOUTS = ["HND"] -DTYPES = [torch.bfloat16] -KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. -SOFT_CAPS = [None, 50.0] +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +DTYPE = [torch.bfloat16] +QUANT_DTYPES = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), +] +BATCH_SIZE = [4, 12] +MAX_SEQ_LENS = [(1024, 4096)] +NUM_HEADS = [(64, 8), (40, 8)] +HEAD_SIZE = [128] +KV_LAYOUT = ["HND"] # currently only HND is supported +BLOCK_SIZE = [16] +SOFT_CAP = [None, 50.0] + +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - torch.set_default_device("cuda") current_platform.seed_everything(0) - kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - max_kv_len = torch.max(kv_lens).item() - num_seqs = len(kv_lens) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + _, max_kv_len = max_seq_lens - scale = head_size**-0.5 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] + for i in range(batch_size): + seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) @@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Decode wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) + workspace_buffer, kv_layout, use_tensor_cores=True) wrapper.plan(kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, "NONE", - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) # TRTLLM Decode - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - output_trtllm = torch.empty(query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=kv_lens_tensor, - max_seq_len=max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 3e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 2e-2 - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) @pytest.mark.parametrize("soft_cap", [None]) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if dtype != kv_cache_dtype: - pytest.skip(f"Not supported dtype({dtype}) with " - "kv_cache_dtype({kv_cache_dtype})") - torch.set_default_device("cuda") current_platform.seed_everything(0) - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN + if q_quant_dtype != kv_quant_dtype: + pytest.skip("Skipped mixed QKV dtypes for prefill") - seq_lens = kv_lens + q_lens - max_seq_len = torch.max(seq_lens).item() - num_seqs = len(seq_lens) + max_q_len, max_kv_len = max_seq_lens - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - scale = head_size**-0.5 - - query = torch.randn(torch.sum(q_lens).item(), - num_query_heads, - head_size, - dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) + + q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + query = torch.randn(torch.sum(q_lens).item(), + num_qo_heads, + head_size, + dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size @@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout) wrapper.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, causal=True, - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) + + # TRTLLM Prefill + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) - # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=dtype) flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 4e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 91fcdffe2c7419c55749c5e3b9115d1a24157ab4..9093adc02e67bf0fca6f79b44cf46a096a38f8c4 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - assert cos_diff < 1e-4 + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-4 #1e-5 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ if not is_flashmla_supported()[0] else "FlashMLA is supported" @@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) -@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) @pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("d", [576]) @@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("torch_dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen, dtype): + varlen, torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}") + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv) + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + def flash_mla(): return flash_mla_with_kvcache( q, @@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, return attn_weight @ value, lse def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ref_O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = ref_O.transpose(0, 1) - lse[i] = LSE + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i return out, lse out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/tests/kernels/attention/untest_attention_selector.py b/tests/kernels/attention/untest_attention_selector.py index af74f1086e28a9ca79dcd37c8c6fab89ab7b54a7..a360069ed0684a2940dc27d90b01b5df86c75b04 100644 --- a/tests/kernels/attention/untest_attention_selector.py +++ b/tests/kernels/attention/untest_attention_selector.py @@ -80,6 +80,9 @@ def test_env( m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + if name == "FLASHINFER" and not use_v1: + pytest.skip("FlashInfer backend is only available on V1 engine") + if device == "cpu": with patch("vllm.attention.selector.current_platform", CpuPlatform()): diff --git a/tests/kernels/attention/untest_flashinfer.py b/tests/kernels/attention/untest_flashinfer.py index 2c8e4f51a3f295c2ded42915502f04c83d9bb750..ca57597d423d50e5f4b577509ed4faa18f1190c4 100644 --- a/tests/kernels/attention/untest_flashinfer.py +++ b/tests/kernels/attention/untest_flashinfer.py @@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv( workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) + use_tensor_cores=True) wrapper.plan( kv_indptr, kv_indices, @@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + use_tensor_cores = True kv_cache_dtype = torch.float8_e4m3fn query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/untest_mamba_ssm_ssd.py similarity index 100% rename from tests/kernels/mamba/test_mamba_ssm_ssd.py rename to tests/kernels/mamba/untest_mamba_ssm_ssd.py diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 6f95581a5e60d0373816aecd49a1384672b2b7f1..36a98522a658805512323d9175e625b5a7b5ca8e 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights @@ -370,9 +370,10 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False] @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("block_size", [[128, 128]]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 43804c410b6c284fadb1c8d7b9cd8c464b4926c6..6a53af68cd53ad168d530c4775089a4244e62245 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.platforms import current_platform from vllm.utils import has_deep_ep +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): @@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, @@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], num_experts: int, topk: int, diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py new file mode 100644 index 0000000000000000000000000000000000000000..52a3d2ca3b422f7f31d99491a14e1cd399bfe3ac --- /dev/null +++ b/tests/kernels/moe/test_flashinfer.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + input_to_float8) +from vllm.model_executor.models.llama4 import Llama4MoE +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +NUM_EXPERTS = [16] +TOP_KS = [1] + +MNK_FACTORS = [ + (256, 8192, 5120), + (256, 4096, 5120), + (127, 8192, 5120), + (127, 4096, 5120), + (10, 8192, 5120), + (10, 4096, 5120), + (1, 8192, 5120), + (1, 4096, 5120), +] + +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + + +def quant_fp8_per_tensor_batches(a): + num_batches = a.size(0) + a_quant = [] + a_scales = [] + + for i in range(num_batches): + a_fp8, a_global_sf = input_to_float8(a[i]) + a_global_sf = 1.0 / a_global_sf + a_quant.append(a_fp8) + a_scales.append(a_global_sf) + + result_a_quant = torch.stack(a_quant) + result_a_scales = torch.stack(a_scales) + + return result_a_quant, result_a_scales + + +@dataclass +class TestData: + hidden_states: torch.Tensor + w13_quantized: torch.Tensor + w2_quantized: torch.Tensor + a1_scale: torch.Tensor + a2_scale: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor + layer: torch.nn.Module + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + reorder: bool) -> "TestData": + hidden_states = torch.randn( + (m, k), device="cuda", dtype=torch.bfloat16) / 10 + w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) + + # Scale to fp8 + _, a1_scale = input_to_float8(hidden_states) + a1_scale = 1.0 / a1_scale + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( + dtype=torch.float32) + w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) + w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) + + layer = torch.nn.Module() + layer.w13_weight = w13_quantized.clone() + layer.w2_weight = w2_quantized.clone() + layer.w13_input_scale = a1_scale + layer.w2_input_scale = a2_scale + layer.w13_weight_scale = w13_weight_scale + layer.w2_weight_scale = w2_weight_scale + + register_moe_scaling_factors(layer) + + # flashinfer expects swapped rows for w13 + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + if reorder: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) + layer.custom_routing_function = Llama4MoE.custom_routing_function + layer.intermediate_size_per_partition = n + layer.ep_rank = 0 + layer.local_num_experts = e + + return TestData( + hidden_states=hidden_states, + w13_quantized=w13_quantized, + w2_quantized=w2_quantized, + a1_scale=a1_scale, + a2_scale=a2_scale, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + layer=layer, + ) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_per_tensor_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( + layer=td.layer, + hidden_states=td.hidden_states, + router_logits=score, + routing_bias=None, + global_num_experts=e, + top_k=topk, + num_expert_group=None, + topk_group=None, + apply_router_weight_on_input=True) + + torch.testing.assert_close(output, + flashinfer_output, + atol=5.5e-2, + rtol=1e-2) + + +@pytest.mark.skip( + "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472" +) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_cutlass_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + td.layer.dp_size = 1 + + flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( + td.hidden_states, + td.layer, + topk_weights, + topk_ids, + activation="silu", + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=True, + ) + + torch.testing.assert_close(output, + flashinfer_cutlass_output, + atol=5.5e-2, + rtol=1e-2) diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..646e763194fd648305f3e1abb768615d77d6c7d7 --- /dev/null +++ b/tests/kernels/moe/test_grouped_topk.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE grouped topk kernel + +Run `pytest tests/kernels/moe/test_grouped_topk.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk, + grouped_topk) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_hidden", [1024, 2048]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("num_expert_group", [8]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, + n_hidden: int, n_expert: int, topk: int, + renormalize: bool, num_expert_group: int, + topk_group: int, scoring_func: str, + routed_scaling_factor: float, dtype: torch.dtype): + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), + dtype=dtype, + device="cuda") + gating_output = torch.randn((n_token, n_expert), + dtype=dtype, + device="cuda") + e_score_correction_bias = torch.randn((n_expert, ), + dtype=torch.float32, + device="cuda") + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + baseline_topk_weights, baseline_topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + test_topk_weights, test_topk_ids = fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + if renormalize: + torch.testing.assert_close(baseline_topk_weights, + test_topk_weights, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_topk_ids, + test_topk_ids, + atol=0, + rtol=0) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index d45982384eb3b297f386574a39c467c7987b0f7b..6112183be5475575d2f5b4d90f05bfb9124ce968 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from ...utils import multi_gpu_test from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, reference_moe_impl, run_modular_kernel) @@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool: product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) +@multi_gpu_test(num_gpus=2) @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index f1e21ecf564a0490d3f264aa1ab7c1031c0cf113..3b4905e8c24487bcdcf67c0d1e544ce712d4eb1d 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -438,11 +438,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) - torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) + torch.cuda.synchronize() torch.cuda.empty_cache() # Run forward passes for both MoE blocks diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 824b072a9f933530e6a5599bb665ecdf4c0c65b1..7bd1ffce58e9624b88ab4e8c98c8a99abb277ff7 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -4,15 +4,27 @@ import importlib import importlib.metadata from dataclasses import dataclass +from typing import Optional import pytest import torch from packaging import version +from vllm.platforms import current_platform + QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( +) and current_platform.is_device_capability(100) + +if TRTLLM_GEN_MXFP4_AVAILABLE: + from flashinfer import (fp4_quantize, mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + @dataclass class ModelCase: @@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) - assert output \ No newline at end of file + assert output + + +def swiglu(x, + alpha: float = 1.702, + beta: float = 1.0, + limit: Optional[float] = None): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + beta) + + +fp4_lookup_table = [ + 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 +] + + +def mxfp4_dequantize(x, scale): + assert x.dtype == torch.uint8 + x = x.view(torch.uint8).to(torch.int32) + x_unpacked = torch.zeros(*x.shape[:-1], + x.shape[-1] * 2, + dtype=torch.int32, + device=x.device) + x_unpacked[..., 0::2].copy_(x & 0xF) + x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) + + x_float = torch.zeros(x_unpacked.shape, + dtype=torch.float32, + device=x.device) + for i, val in enumerate(fp4_lookup_table): + x_float[x_unpacked == i] = val + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def mxfp8_dequantize(x, scale): + assert x.dtype == torch.float8_e4m3fn + x_float = x.to(torch.float32) + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def reference_moe( + roouting_logits, + topk, + num_experts, + hidden_states, + w13, + bias13, + w2, + bias2, + alpha, + beta, + limit, + act_type, +): + # renormalize routing + experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + t = hidden_states.clone() + # MLP #1 + mlp1_weight = w13[expert_indices, ...] + mlp1_bias = bias13[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + + if act_type == 'mxfp8': + t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), + is_sf_swizzled_layout=False) + t = mxfp8_dequantize(t_quantized, t_scale) + # MLP #2 + mlp2_weight = w2[expert_indices, ...] + mlp2_bias = bias2[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + assert t.shape == hidden_states.shape + return t.to(torch.bfloat16) + + +def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13_weight, + w13_weight_scale, + w13_bias, + w2_weight, + w2_weight_scale, + w2_bias, + act_type, + alpha, + beta, + limit, +) -> torch.Tensor: + sf_block_size = 32 + assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2) + assert (w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size) + assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2) + assert (w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) + assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2) + assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size) + + # Swap w1 and w3 as the defenition of + # swiglu is different in the trtllm-gen + w13_weight_scale_ = w13_weight_scale.clone() + w13_weight_ = w13_weight.clone() + w13_bias_ = w13_bias.clone() + w13_weight[:, :intermediate_size, :].copy_( + w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_( + w13_weight_[:, :intermediate_size, :]) + w13_weight_scale[:, :intermediate_size, :].copy_( + w13_weight_scale_[:, intermediate_size:, :]) + w13_weight_scale[:, intermediate_size:, :].copy_( + w13_weight_scale_[:, :intermediate_size, :]) + w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) + w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) + + # Interleave the weights and scaling factors for activation + w13_weight_interleaved = [] + w13_weight_scale_interleaved = [] + w13_bias_interleaved = [] + for i in range(num_experts): + w13_weight_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) + w13_weight_scale_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + w13_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, + 1))) + w13_weight = torch.stack(w13_weight_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) + w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 32) + w13_bias = torch.stack(w13_bias_interleaved).reshape( + num_experts, 2 * intermediate_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_shuffled) + w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( + num_experts, 2 * intermediate_size, + hidden_size // sf_block_size).view(torch.float8_e4m3fn) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) + + w2_weight = torch.stack(gemm2_weights_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( + num_experts, hidden_size, + intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) + + tg_result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale, + gemm1_bias=w13_bias, + gemm1_alpha=alpha, + gemm1_beta=beta, + gemm1_clamp_limit=limit, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale, + gemm2_bias=w2_bias, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), + routing_method_type=1, # renormalize + do_finalize=True)[0] + return tg_result + + +def check_accuracy(a, b, atol, rtol, percent): + """Allow a mismatch percentage of 1 - percent.""" + if torch.any(torch.isnan(a)): + raise Exception("NaN in reference output") + if torch.any(torch.isnan(b)): + raise Exception("NaN in actual output") + if torch.any(torch.isinf(a)): + raise Exception("Inf in reference output") + if torch.any(torch.isinf(b)): + raise Exception("Inf in actual output") + assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " + f"(threshold: {1-percent:.4f})") + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32, 128]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP4_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test") +def test_trtllm_gen_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], + act_type: str, +): + seed = 42 + torch.manual_seed(seed) + hidden_states = torch.randn(num_tokens, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16) + w13 = (torch.randn(num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16)) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16)) + bias13 = torch.randn(num_experts, intermediate_size * 2, + device="cuda:0") * 10 + bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 + router_logits = torch.rand(num_tokens, num_experts, + dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize(w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, hidden_size // 32) + w2, w2_scale = fp4_quantize(w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 32) + if act_type == 'mxfp8': + hidden_states, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False) + hidden_states_scale = hidden_states_scale.view( + torch.float8_e4m3fn).reshape(-1) + else: + hidden_states_scale = None + + # reference result + ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) + w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) + bias13_ref = bias13 + bias2_ref = bias2 + if act_type == 'mxfp8': + hidden_states_ref = mxfp8_dequantize( + hidden_states, hidden_states_scale).to(torch.float32) + else: + hidden_states_ref = hidden_states.to(torch.float32) + # Process tokens in chunks of 32 to reduce memory usage + chunk_size = 32 + num_chunks = (num_tokens + chunk_size - 1) // chunk_size + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_result = reference_moe( + router_logits[start_idx:end_idx].to(torch.float32), + topk, + num_experts, + hidden_states_ref[start_idx:end_idx], + w13_ref, + bias13_ref, + w2_ref, + bias2_ref, + alpha, + beta, + limit, + act_type, + ) + ref_result[start_idx:end_idx].copy_(chunk_result) + + # trtllm-gen result + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + tg_result = tg_mxfp4_moe(router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + bias13, + w2, + w2_scale, + bias2, + act_type, + alpha=alpha, + beta=beta, + limit=limit) + # relatively loose check since the mxfp4 quantization is less accurate + check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c2064de97358f64b37bceb27d606df4ac8686fa1..3f36d7ada2e94de7e778cc309e2ba643e27dfd2c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.platforms import current_platform from vllm.utils import round_up +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( @@ -452,6 +453,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], e: int, @@ -740,6 +742,7 @@ def _pplx_moe( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe_slow( mnk: tuple[int, int, int], e: int, @@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize( world_dp_size: tuple[int, int], use_internode: bool, @@ -893,6 +897,7 @@ def test_pplx_prepare_finalize( @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, diff --git a/tests/kernels/moe/untest_block_fp8.py b/tests/kernels/moe/untest_block_fp8.py index 9e4eaf221f245bf50fb7d27c002b0eff00e4ca9e..ecc57acc679634c82e23bd6c51c855a3f12f3fc1 100644 --- a/tests/kernels/moe/untest_block_fp8.py +++ b/tests/kernels/moe/untest_block_fp8.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): diff --git a/tests/kernels/moe/untest_cutlass_moe.py b/tests/kernels/moe/untest_cutlass_moe.py index 81fb3ec1de188854783169d45eea957756a6ad07..c84f66383b9027acfbb75020332c5d5d26bdd657 100644 --- a/tests/kernels/moe/untest_cutlass_moe.py +++ b/tests/kernels/moe/untest_cutlass_moe.py @@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, + 'ab_strides1': moe_tensors.ab_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides1': moe_tensors.c_strides1, + 'c_strides2': moe_tensors.c_strides2, 'per_act_token': per_act_token, 'a1_scale': None #moe_tensors.a_scale } @@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8( topk_ids[0][1] = 1 workspace13_shape = (m * topk, max(2 * n, k)) - workspace2_shape = (m * topk, n) - output_shape = (m * topk, k) + workspace2_shape = (m * topk, max(n, k)) + output_shape = (m, k) workspace13 = torch.empty(prod(workspace13_shape), device="cuda", @@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, torch.float8_e4m3fn, @@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8( output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, - per_act_token, per_out_channel, False) + a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, + workspace13, workspace2, None, mt.a.dtype, per_act_token, + per_out_channel, False, topk_weights) workspace13.random_() output_random_workspace = torch.empty(output_shape, diff --git a/tests/kernels/moe/untest_moe_permute_unpermute.py b/tests/kernels/moe/untest_moe_permute_unpermute.py index 6ca01f9271bbadcf0e09ae18e99d239eef8520dd..d71664d94b9c8de6f488f8d9774685afcfa7b6ad 100644 --- a/tests/kernels/moe/untest_moe_permute_unpermute.py +++ b/tests/kernels/moe/untest_moe_permute_unpermute.py @@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, atol=0, rtol=0) # check mindice - torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # current kernel usage assumes deepgemm requires align_block_size + # when it's not provided then we don't compute m_indices (for cutlass) + if align_block_size is not None: + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # check permuted_hidden_states, only valid token torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], permuted_hidden_states[valid_row_idx], diff --git a/tests/kernels/moe/untest_pplx_cutlass_moe.py b/tests/kernels/moe/untest_pplx_cutlass_moe.py index f98937ee6c5276258d07093dc74d17cb079adb01..9e78f4d6e4da033fe6b0f50b5e85d2f1f650765f 100644 --- a/tests/kernels/moe/untest_pplx_cutlass_moe.py +++ b/tests/kernels/moe/untest_pplx_cutlass_moe.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.platforms import current_platform from vllm.utils import cdiv +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch try: @@ -76,6 +77,7 @@ def pplx_cutlass_moe( assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape + intermediate_dim = w2.shape[2] num_experts = w1.shape[0] block_size = hidden_dim # TODO support more cases device = pgi.device @@ -124,8 +126,27 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) + ab_strides1 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_local_experts, ), + intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_local_experts, ), + 2 * intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, - out_dtype, per_act_token, per_out_ch) + out_dtype, per_act_token, per_out_ch, + ab_strides1, ab_strides2, c_strides1, + c_strides2) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -227,6 +248,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) +@multi_gpu_test(num_gpus=2) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), diff --git a/tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py index 673a0aa367948845fc48e3e1d93a1791856077c9..5a0379dfb4475a094dc92886c8345b62621bcef6 100644 --- a/tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py @@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): current_platform.seed_everything(seed) # Input tensor of shape (E, T, 2*H) - y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda") + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, @@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): y_se = y_s[e] y_qe = y_q[e] - torch.testing.assert_close(y_se[:nt], ref_s[:nt]) + torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index c8eb46770fdf5d8b37db316c8d357e89ffe35c5d..478b9ae534fddf142e9f713db8a5d8c11767b410 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the AWQ Triton kernel. -Run `pytest tests/kernels/test_awq_triton.py`. +Run `pytest tests/kernels/quantization/test_awq_triton.py`. """ import pytest import torch diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 0000000000000000000000000000000000000000..f659408efe8c68d056fbc456411158b77a1dd27e --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. + +Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672), + (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096), + (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096), + (1024, 4096, 8192), (1024, 8192, 4096)] + +# TODO(czhu): get supported schedules from fn +SCHEDULES = [ + '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1', + '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1', + '128x256_1x1x1', '128x256_2x1x1' +] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: torch.Tensor + w_ch_s: torch.Tensor + w_tok_s: torch.Tensor + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + *( + TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32) + for w_type in [scalar_types.int4] + # TODO(czhu): fp16 out type + for o_type in [torch.bfloat16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, + max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights(w, + wtype, + group_size=group_size, + zero_points=zero_points) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( + torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> Tensors: + m, n, k = shape + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + False) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + # weights are already per-group quantized, use placeholder here + w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +def mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=1e-3, + atol=1e-3) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + + w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) + w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..9f669c6df8bd5fc579af41491a49c943e490574b --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason= + "Flashinfer FP8 gemms requires compute capability of 10.0 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_fp8_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + use_bias: bool, + seed: int, + device: str, + autotune: bool, +) -> None: + current_platform.seed_everything(seed) + m, n, k = shape + a = torch.randn((m, k), dtype=dtype, device=device) + b = torch.randn((n, k), dtype=dtype, device=device) / k + + a_fp8, a_scale = ops.scaled_fp8_quant(a) + b_fp8, b_scale = ops.scaled_fp8_quant(b) + + expected_out = torch.mm( + a_scale * a_fp8.to(dtype=torch.float32), + b_scale * b_fp8.to(dtype=torch.float32).t(), + ).to(dtype=dtype) + + if use_bias: + bias = torch.randn((n, ), dtype=dtype, device=device) + expected_out = expected_out + bias + else: + bias = None + + import flashinfer + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp8_mm( + a_fp8, + b_fp8.t(), + a_scale, + b_scale, + dtype, + bias=bias, + ) + + torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..969f14cc3fe626f90482dab4ed7cc948482dd74a --- /dev/null +++ b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +BLOCK_SIZE = 16 + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + silu_and_mul_out = silu_and_mul.forward_native(x) + assert not current_platform.is_rocm() + assert silu_and_mul_out.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.') + other_dims = 1 if silu_and_mul_out.ndim == 1 else -1 + silu_and_mul_out = silu_and_mul_out.reshape(other_dims, + silu_and_mul_out.shape[-1]) + m, n = silu_and_mul_out.shape + device = silu_and_mul_out.device + + # Two fp4 values will be packed into an uint8. + out = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + output_scale = ref_output_scale + + torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale, + global_scale) + + return out, output_scale + + +def ops_impl(x: torch.Tensor, global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 4) + output_scale = ref_output_scale + out = torch.empty(out_shape, dtype=torch.uint8, device=x.device) + torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale) + return out, output_scale + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_quantize_to_fp4( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + m, n = shape + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + block_size = 16 + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert x.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.') + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(x.shape[0], 128) + scale_n = x.shape[1] // (2 * block_size) + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=x.device, + dtype=torch.int32) + + layer = SiluAndMul() + + ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale) + + fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale) + + assert ref_out.dtype == torch.uint8 + assert fusion_out.dtype == torch.uint8 + assert ref_out.shape == fusion_out.shape + + assert ref_out_scale.dtype == torch.int32 + assert fusion_out_scale.dtype == torch.int32 + assert ref_out_scale.shape == fusion_out_scale.shape + + # Allow up to 2% of mismatched values since BF16 has accuracy issues. + mis_threshold = 0.02 + atol = 0.4 + rtol = 0.4 + ref_logits = ref_out[-1] + fusion_logits = fusion_out[-1] + + mis_count = torch.sum( + torch.abs(fusion_logits - ref_logits) > (atol + + rtol * torch.abs(ref_logits))) + mis_ratio = mis_count / fusion_logits.numel() + + assert mis_ratio < mis_threshold, \ + f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}" + + torch.testing.assert_close(ref_out_scale, fusion_out_scale) + + opcheck(torch.ops._C.silu_and_mul_nvfp4_quant, + (fusion_out, fusion_out_scale, x, global_scale)) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 3ddba9ee8f527ae8ce16b264584312b60b681749..de6e5f72f26bde78ccaa706e7ef4e19321d58e1b 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the triton_scaled_mm kernel -Run `pytest tests/kernels/test_triton_scaled_mm.py`. +Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. """ import os import importlib diff --git a/tests/kernels/quantization/untest_cutlass_2of4_sparse.py b/tests/kernels/quantization/untest_cutlass_2of4_sparse.py index b81e46aef596d910051b5b4f94011e277f4eae79..0b94a662c62ce39ff89d16dfb8a0b599179bc97c 100644 --- a/tests/kernels/quantization/untest_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/untest_cutlass_2of4_sparse.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for sparse cutlass kernels -Run `pytest tests/kernels/test_semi_structured.py`. +Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`. """ import pytest diff --git a/tests/kernels/quantization/untest_cutlass_scaled_mm.py b/tests/kernels/quantization/untest_cutlass_scaled_mm.py index 8730eeaaa761c67f0ca3903ebbb0f011f32b25f3..65320509e173f36fa9e9e371ba2d5369ed1a0804 100644 --- a/tests/kernels/quantization/untest_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/untest_cutlass_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for cutlass kernels -Run `pytest tests/kernels/test_cutlass.py`. +Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`. """ import random @@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, expert_offsets = torch.zeros((num_experts + 1), device=device, - dtype=torch.int32) + dtype=torch.int64) problem_sizes = torch.zeros((num_experts, 3), device=device, diff --git a/tests/kernels/quantization/untest_machete_mm.py b/tests/kernels/quantization/untest_machete_mm.py index 969d54528c1dbe47b100cdc51ef70c4692c40aef..17ec0a08e528c2a71a493b778aad0dc3770cd133 100644 --- a/tests/kernels/quantization/untest_machete_mm.py +++ b/tests/kernels/quantization/untest_machete_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the machete kernel. -Run `pytest tests/kernels/test_machete_mm.py`. +Run `pytest tests/kernels/quantization/test_machete_mm.py`. """ import math @@ -95,23 +95,23 @@ TEST_TYPES = [ token_scale_type=None) for w_type in [scalar_types.uint4, scalar_types.uint8] for a_type in [torch.float16, torch.bfloat16]), - # QQQ style - *(TypeConfig(act_type=torch.int8, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), - *(TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), + # # QQQ style + # *(TypeConfig(act_type=torch.int8, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), + # *(TypeConfig(act_type=torch.float8_e4m3fn, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel diff --git a/tests/kernels/quantization/untest_marlin_gemm.py b/tests/kernels/quantization/untest_marlin_gemm.py index c354dcd3568f487dd20b01b875e8954cb01b3d55..0be020085bfa42698d4090f3cb6b982c39b515e2 100644 --- a/tests/kernels/quantization/untest_marlin_gemm.py +++ b/tests/kernels/quantization/untest_marlin_gemm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the marlin kernel. -Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. """ import pytest import torch @@ -13,11 +13,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) @@ -31,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 - marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import scalar_types @@ -449,68 +443,6 @@ def test_hqq_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("qqq"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_qqq_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, -): - int8_traits = torch.iinfo(torch.int8) - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - # Quantize activations - s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( - torch.float) - q_a = (a_input / s_a).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - # Quantize weights - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ - marlin_qqq_quantize(b_weight, num_bits, group_size) - - workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_MAX_PARALLEL) - - opcheck(torch.ops._C.marlin_qqq_gemm, - (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel, - marlin_qqq_s_group, workspace.scratch, a_input.shape[0], - b_weight.shape[1], a_input.shape[1])) - - output = ops.marlin_qqq_gemm( - q_a, - marlin_qqq_q_w, - s_a, - marlin_qqq_s_channel, - marlin_qqq_s_group, - workspace.scratch, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ) - output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - def test_marlin_gemm_subset_input(): quant_type = scalar_types.uint4b8 group_size = 128 @@ -602,18 +534,3 @@ def test_marlin_gemm_with_bias(size_m): max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 - - -def test_marlin_gemm_opcheck(): - size_m = 2048 - size_n = 4096 - size_k = 4096 - a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) - w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) - s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) - wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL).scratch - x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - torch.testing.assert_close(x, y) - opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) \ No newline at end of file diff --git a/tests/kernels/quantization/untest_triton_scaled_mm.py b/tests/kernels/quantization/untest_triton_scaled_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..de6e5f72f26bde78ccaa706e7ef4e19321d58e1b --- /dev/null +++ b/tests/kernels/quantization/untest_triton_scaled_mm.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the triton_scaled_mm kernel + +Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. +""" +import os +import importlib +from typing import Optional + +import pytest +import torch + +from vllm.platforms import current_platform +from ...utils import models_path_prefix + +device = "cuda" + +triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") +triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + + +def torch_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out = torch.mm(a.to(torch.float32), b.to(torch.float32)) + out = scale_a * out + out = scale_b.T * out + out = out.to(out_dtype) + if bias is not None: + out = out + bias + + return out + + +def get_8bit_types(): + types = [torch.int8] + if current_platform.supports_fp8(): + types.append(current_platform.fp8_dtype()) + return types + + +# This test is to check regressions for int8 support on ROCm. +@pytest.mark.parametrize("model_path", [ + os.path.join(models_path_prefix, "neuralmagic/Llama-3.2-1B-quantized.w8a8"), +]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [10]) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="Should only run on ROCm") +def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, + max_tokens, num_logprobs): + dtype = "bfloat16" + + with vllm_runner(model_path, dtype=dtype) as vllm_model: + vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, + num_logprobs) + + +MNK_FACTORS = [ + (1, 256, 128), + (33, 256, 496), + (64, 971, 1024), + (64, 20486, 128), + (512, 256, 496), + (512, 20486, 1024), +] + + +@pytest.mark.parametrize("M,N,K", MNK_FACTORS) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("in_dtype", get_8bit_types()) +@pytest.mark.parametrize("use_scalar_scale_a", [True, False]) +@pytest.mark.parametrize("use_scalar_scale_b", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, + use_scalar_scale_b, use_bias): + is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t + ).is_floating_point() + + current_platform.seed_everything(0) + + # NOTE: There are cases, where if the matrix is large enough, an output + # like 65504.4 can be produced, and can easily turn into inf when + # multiplied when using float16/bfloat16. This means one function, e.g., + # testing function, and another function, e.g. golden function, can + # produce a non-inf value while the other produces an inf value, and + # will cause assert_close/allclose to fail, even though if overflow + # wouldn't have occurred, the values would have been "close." + # + # So, the values here are kept small enough to avoid this situation. + if is_floating_point_type(in_dtype): + a = (0.25 * torch.rand( + (M, K), dtype=torch.float32, device=device)).to(in_dtype) + b = (0.25 * torch.rand( + (K, N), dtype=torch.float32, device=device)).to(in_dtype) + else: + a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device) + b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device) + + if use_scalar_scale_a: + scale_a = torch.rand((1, 1), device=device) + else: + scale_a = 0.25 * torch.rand((M, 1), device=device) + + if use_scalar_scale_b: + scale_b = torch.rand((1, 1), device=device) + else: + scale_b = 0.25 * torch.rand((N, 1), device=device) + + bias = None + if use_bias: + bias = torch.rand((N, ), device=device, dtype=out_dtype) + + c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1) diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index 9b7c526d670029598d1eff560e95bfb48dc160bc..bc556dedc2617e44a30a8b069ec94ac6b43e700f 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -10,13 +10,19 @@ import pytest import torch from packaging import version -from vllm import SamplingParams + +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionMetadataBuilder) from ..utils import models_path_prefix -from ..models.utils import check_embeddings_close +from ..models.utils import check_embeddings_close, check_logprobs_close TORCH_VERSION = version.parse(torch.__version__) MINIMUM_TORCH_VERSION = version.parse("2.7.0") +DIRECT_BUILD_VERSION = version.parse("2.9.dev0") def set_seed(seed): @@ -36,22 +42,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with - the default backend, ensuring they are identical when using the same seed. + the default backend, ensuring they are similar when using the same seed. """ model_name = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct") seed = 42 max_tokens = 24 + num_logprobs = 5 prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) - # Run with flex attention with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -63,7 +65,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): tensor_parallel_size=1, num_gpu_blocks_override=128, enforce_eager=True) as llm_flex: - output_flex = llm_flex.generate(prompts, sampling_params) + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Run with default backend with monkeypatch.context() as m: @@ -73,20 +76,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): runner="generate", tensor_parallel_size=1, num_gpu_blocks_override=128, - enforce_eager=True) as llm_default: - output_default = llm_default.generate(prompts, sampling_params) - - # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): - prompt = prompts[i] - flex_text = flex_result[1][0] - default_text = default_result[1][0] - - assert flex_text == default_text, ( - f"FlexAttention output doesn't match default for: {prompt!r}\n" - f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + enforce_eager=True, + gpu_memory_utilization=0.85) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=output_flex, + outputs_1_lst=output_default, + name_0="flex", + name_1="default", + ) @pytest.mark.skipif( @@ -138,5 +138,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ) +@pytest.mark.skipif( + not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION, + reason="CUDA not available or PyTorch version < 2.7", +) +def test_block_mask_direct_vs_slow_path(): + """Test that direct path block mask is a superset of slow path. + + The direct path may include extra blocks for performance (over-estimation), + but must include all blocks that the slow path determines are necessary. + """ + device = torch.device("cuda") + + vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B", + block_size=16, + max_model_len=1024) + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # Use a mixed batch that will create groups spanning multiple sequences + batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256], + query_lens=[33, 5, 32, 64], + name="test_mixed_batch") + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, + device) + + metadata_direct = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + builder.direct_build = False + metadata_slow = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + assert metadata_direct.block_mask is not None + assert metadata_slow.block_mask is not None + + # Extract block indices for comparison, B, H are the same + direct_indices = metadata_direct.block_mask.kv_indices[0, 0] + slow_indices = metadata_slow.block_mask.kv_indices[0, 0] + direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0] + slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0] + + # main test: every block needed by slow path must be in direct path + num_groups = direct_num.shape[0] + all_contained = True + missing_details = [] + + for group_idx in range(num_groups): + direct_blocks = set( + direct_indices[group_idx, :direct_num[group_idx]].tolist()) + slow_blocks = set( + slow_indices[group_idx, :slow_num[group_idx]].tolist()) + + missing_blocks = slow_blocks - direct_blocks + if missing_blocks: + all_contained = False + missing_details.append( + f"Group {group_idx}: missing {sorted(missing_blocks)}") + + assert all_contained, ( + "Direct path is missing blocks required by slow path:\n" + + "\n".join(missing_details)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py new file mode 100644 index 0000000000000000000000000000000000000000..17692384ac9a9f0b279678ff678912f75ffa369b --- /dev/null +++ b/tests/kernels/test_onednn.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Integration tests for FlexAttention backend vs default backend""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import to_int8 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +NK_FACTORS = [ + (256, 128), + (4096, 4096), + (16384, 4096), + (1023, 491), + (1001, 15), +] +M_FACTORS = [ + (16, 1, 32, 128, 64), + (1, 17, 1, 31, 17), +] +CACHE_SIZES = [2] +DTYPE = [torch.bfloat16] + + +def rand_int8(shape: tuple, device: str = "cpu"): + return to_int8(torch.rand(shape, device=device) * 255 - 128) + + +def ref_int8_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + azp: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + output_type: torch.dtype, +): + if azp is not None: + a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32) + output = torch.mm((scale_a * a.to(dtype=torch.float32)), + (scale_b * b.to(dtype=torch.float32))) + if bias is not None: + output += bias.float() + + return output.to(dtype=output_type) + + +def onednn_int8_gemm_test_helper(primitive_cache_size: int, + m: int, + n: int, + k: int, + per_tensor_a_quant: bool, + per_tensor_b_quant: bool, + use_azp: bool, + use_bias: bool, + out_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu"): + # Test for a oneDNN kernel with per-tensor / per-token activation + # quantization and per-tensor / per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1) + b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n) + + scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) + scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + + if use_azp: + azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5 + azp = (azp / scale_a).round().to(dtype=torch.int32) + azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32) + else: + azp = None + azp_adj = None + + if use_bias: + bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 + else: + bias = None + + handler = ops.create_onednn_scaled_mm( + b, + scale_b, + out_dtype, + not per_tensor_a_quant, + use_azp, + primitive_cache_size, + ) + + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, bias) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, bias, out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + if use_bias: + # To test runtime bias setting + out = torch.zeros((m, n), dtype=out_dtype) + ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, + out_dtype) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("per_tensor_a_scale", [True, False]) +@pytest.mark.parametrize("per_tensor_b_scale", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_azp", [True, False]) +@pytest.mark.parametrize("output_type", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_int8_scaled_gemm( + n: int, + k: int, + m_list: tuple[int], + per_tensor_a_scale: bool, + per_tensor_b_scale: bool, + use_bias: bool, + use_azp: bool, + output_type: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_int8_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + per_tensor_a_quant=per_tensor_a_scale, + per_tensor_b_quant=per_tensor_b_scale, + use_bias=use_bias, + use_azp=use_azp, + out_dtype=output_type, + ) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 73404b08edf19aea084933c506ae00516d6d1b91..ed268ae1100ea89aa5e09269c8dd3a5447671a51 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -3,7 +3,7 @@ import tempfile from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import os @@ -11,8 +11,6 @@ import torch import torch.nn as nn from huggingface_hub import snapshot_download -import vllm -from vllm.config import LoRAConfig from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -22,7 +20,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform from ..utils import models_path_prefix @@ -106,6 +103,7 @@ def dummy_model() -> nn.Module: ])) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 return model @@ -139,6 +137,8 @@ def dummy_model_gate_up() -> nn.Module: ], } model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 + return model @@ -223,40 +223,6 @@ def tinyllama_lora_files(): return os.path.join(models_path_prefix, "jashing/tinyllama-colorist-lora") -@pytest.fixture(scope="session") -def phi2_lora_files(): - # return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - return os.path.join(models_path_prefix, "isotr0py/phi-2-test-sql-lora") - -@pytest.fixture(scope="session") -def qwen_lora_files(): - # return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider") - return os.path.join(models_path_prefix, "customize/qwen-nl2dsl-lora") - - -@pytest.fixture -def llama_2_7b_engine_extra_embeddings(): - cleanup_dist_env_and_memory(shutdown_ray=True) - get_model_old = get_model - - def get_model_patched(**kwargs): - kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, - max_lora_rank=8) - return get_model_old(**kwargs) - - with patch("vllm.worker.model_runner.get_model", get_model_patched): - engine = vllm.LLM(os.path.join(models_path_prefix, "meta-llama/Llama-2-7b-hf"), enable_lora=False) - yield engine.llm_engine - del engine - cleanup_dist_env_and_memory(shutdown_ray=True) - - -@pytest.fixture -def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): - yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. - model_runner.model) - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index e0bca3457233fb9dc935fd54d497ed93a82628cd..e866464150d2a4de261cd2d7c6f0f9f1454dd7a4 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -6,7 +6,6 @@ import time import os import pytest -import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) @@ -100,12 +99,10 @@ async def test_add_lora(chatglm3_lora_files): # Run with warmup add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] add_lora_results = await asyncio.gather(*add_lora_tasks) - if env.VLLM_USE_V1: - # Test that all all_lora calls are successful. - assert all(add_lora_results) - else: - # No way to check V0 engine results as the calls just return None. - pass + + # Test that all all_lora calls are successful. + assert all(add_lora_results) + time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 63b07bcaad944c400a13afae47135d07cee42995..2b8740d6ced065e9a1ab62262f342aa826e74aed 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -89,6 +89,9 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): + # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for + # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use + # more GPU memory causing vLLM to OOM llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, @@ -97,7 +100,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + gpu_memory_utilization=0.85) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc28efb6c0bb2e424f021c30a246425..6e2dda464d8eb74a279150c9dd1874fb82fab517 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -243,7 +243,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -347,7 +347,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -486,7 +486,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) @@ -620,12 +620,15 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,10 +637,11 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): @@ -651,10 +655,6 @@ def test_linear_replicated(dist_init, num_loras, device, stage, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -734,14 +734,13 @@ def test_linear_replicated(dist_init, num_loras, device, stage, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,11 +749,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": @@ -777,10 +777,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -860,14 +857,13 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,11 +872,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: @@ -924,10 +921,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, model_config=FakeConfig()) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 067618a1b54a0365d5ca85e3945455999b75ff25..ae4f74efd68032af2e7aae5c65fb6bf9c6f9019b 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -114,8 +114,7 @@ def test_llama_lora(sql_lora_files): enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, - max_loras=4, - enable_chunked_prefill=True) + max_loras=4) generate_and_test(llm, sql_lora_files) @@ -129,7 +128,6 @@ def test_llama_lora_tp4(sql_lora_files): max_num_seqs=16, max_loras=4, tensor_parallel_size=4, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) @@ -145,7 +143,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): max_loras=4, tensor_parallel_size=4, fully_sharded_loras=True, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) diff --git a/tests/lora/test_multi_loras_with_tp.py b/tests/lora/test_llm_with_multi_loras.py similarity index 80% rename from tests/lora/test_multi_loras_with_tp.py rename to tests/lora/test_llm_with_multi_loras.py index fe9bd3f26951526c4aaf26abdf109435a16e5f55..3d8dd512a201939b65297073e847afaf0a38bc9e 100644 --- a/tests/lora/test_multi_loras_with_tp.py +++ b/tests/lora/test_llm_with_multi_loras.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Script to test multi loras service with tp >= 2 +This script contains: +1. test multi loras service with tp >= 2 +2. test multi loras request """ +import pytest + from tests.utils import multi_gpu_test from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -156,3 +160,34 @@ def test_multi_loras_with_tp_sync(): output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) + + +def test_multiple_lora_requests(): + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_lora_rank=LORA_RANK, + max_model_len=512, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + PROMPTS = ["Hello, my name is"] * 2 + LORA_NAME = "Alice" + lora_request = [ + LoRARequest(LORA_NAME + str(idx), idx + 1, + LORA_NAME_PATH_MAP[LORA_NAME]) + for idx in range(len(PROMPTS)) + ] + # Multiple SamplingParams should be matched with each prompt + outputs = llm.generate(PROMPTS, lora_request=lora_request) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) + + # Single LoRARequest should be applied to every prompt + single_lora_request = lora_request[0] + outputs = llm.generate(PROMPTS, lora_request=single_lora_request) + assert len(PROMPTS) == len(outputs) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8f8a27006cf67ddd0737c28d0bf4e8c029a12a37..c9ab32edc7f32d703ae0fd29db27ad50b0004d15 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.platforms import current_platform +from .utils import create_peft_lora + EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -35,17 +37,6 @@ DEVICES = ([ DEFAULT_DTYPE = torch.get_default_dtype() -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Some tests depend on V0 internals. Since both V0 and V1 use the same - LoRAModelManager it is okay to just test V0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( @@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): max_loras=2, lora_dtype=DEFAULT_DTYPE), device=device) - assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, + tmp_path): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, 2, + dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, + lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device @@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, + tmp_path): # Should remove every LoRA not specified in the request. lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + 4, 2, dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager(dummy_model_gate_up) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model_gate_up, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 @@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 60663f0b5db519ade1f21b4c9abccccec7c8723c..7f0ff55087e65fad0ad25a67f3a1488fc808feac 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -52,7 +52,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): max_loras=4, distributed_executor_backend="ray", tensor_parallel_size=tp_size, - enable_chunked_prefill=True, ) expected_lora_output = [ diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 76f1bdf70aa3475d6c7cb0e3be4c63412117c0ca..f6bf78104844b4785b1751eb4d0e04c96bfb3df6 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,17 +4,15 @@ import os import random import tempfile -from typing import Union from unittest.mock import patch -import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.v1.worker.gpu_worker import Worker as V1Worker -from vllm.worker.worker import Worker + +from vllm.v1.worker.gpu_worker import Worker from ..utils import models_path_prefix NUM_LORAS = 16 @@ -23,18 +21,11 @@ NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Union[Worker, V1Worker], - lora_requests: list[LoRARequest]): + def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) - if isinstance(worker, Worker): - # v0 case - worker.model_runner.set_active_loras(lora_requests, lora_mapping) - else: - # v1 case - worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) - worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker + worker.model_runner.lora_manager.set_active_adapters( + lora_requests, lora_mapping) vllm_config = VllmConfig( model_config=ModelConfig( @@ -63,7 +54,7 @@ def test_worker_apply_lora(sql_lora_files): max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS), ) - worker = worker_cls( + worker = Worker( vllm_config=vllm_config, local_rank=0, rank=0, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cc1b0d81955bc05c08ae7e5321724092bd15b5f5..7cda90787b6f154e9c5bef312b2af65800432778 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os from dataclasses import dataclass from typing import Optional, Union import torch +from safetensors.torch import save_file from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights @@ -340,3 +343,76 @@ def generate_data_for_nslices( seq_len_tensor, indices, ) + + +def create_peft_lora( + model: torch.nn.Module, + save_dir: str, + target_modules: list[str], + rank: int = 8, + alpha: int = 16, + dropout: float = 0.1, + lora_dtype: torch.dtype = torch.float16, +) -> dict[str, torch.Tensor]: + lora_weights = {} + adapter_config = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": "dummy_model", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": rank, + "lora_alpha": alpha, + "lora_dropout": dropout, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + "target_modules": target_modules, + "exclude_modules": None, + "use_rslora": False, + "use_dora": False, + "loftq_config": None, + } + + for module_name in target_modules: + + module = model + for attr in module_name.split("."): + module = getattr(module, attr) + + if hasattr(module, "input_size") and hasattr(module, "output_size"): + + in_features = module.input_size + out_features = module.output_size + + elif hasattr(module, "embedding_dim") and hasattr( + module, "num_embeddings"): + # ParallelLMHead + in_features = module.embedding_dim + out_features = module.num_embeddings + else: + raise ValueError( + f"Unable to determine dimensions for module {module_name}") + + lora_A = torch.randn(rank, in_features, dtype=lora_dtype) + + torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5) + + lora_B = torch.zeros(out_features, rank, dtype=lora_dtype) + + # PEFT style + lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A + lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B + + config_path = os.path.join(save_dir, "adapter_config.json") + with open(config_path, "w", encoding="utf-8") as f: + json.dump(adapter_config, f, indent=2, ensure_ascii=False) + + weights_path = os.path.join(save_dir, "adapter_model.safetensors") + save_file(lora_weights, weights_path) + + return lora_weights diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index cda341f4a71922532b4fac95741ecf5afe41a280..18a49d7ae55a908c074a276b08ecde5d826e501c 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -94,7 +94,8 @@ AITER_MODEL_LIST = [ pytest.param( os.path.join(models_path_prefix, "allenai/OLMoE-1B-7B-0924-Instruct"), marks=[pytest.mark.cpu_model], - ) + ), + pytest.param("swiss-ai/Apertus-8B"), # apertus ]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index d8b0a8f647a667df1b533a4dc134f3c5b6f58e3a..5d2ec93a4c5fadae8b76a13192f23fa4a088049c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -33,6 +33,7 @@ HYBRID_MODELS = [ os.path.join(models_path_prefix, "hmellor/tiny-random-BambaForCausalLM"), os.path.join(models_path_prefix, "ibm-granite/granite-4.0-tiny-preview"), os.path.join(models_path_prefix, "tiiuae/Falcon-H1-0.5B-Base"), + os.path.join(models_path_prefix, "LiquidAI/LFM2-1.2B"), ] HF_UNSUPPORTED_MODELS = [ @@ -47,26 +48,28 @@ HF_UNSUPPORTED_MODELS = [ ] V1_SUPPORTED_MODELS = [ - os.path.join(models_path_prefix, "state-spaces/mamba-130m-hf"), - os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"), - os.path.join(models_path_prefix, "yujiepan/mamba2-codestral-v0.1-tiny-random"), - os.path.join(models_path_prefix, "Zyphra/Zamba2-1.2B-instruct"), - os.path.join(models_path_prefix, "hmellor/tiny-random-BambaForCausalLM"), - os.path.join(models_path_prefix, "ibm-granite/granite-4.0-tiny-preview"), - os.path.join(models_path_prefix, "tiiuae/Falcon-H1-0.5B-Base"), + os.path.join(models_path_prefix,"state-spaces/mamba-130m-hf"), + os.path.join(models_path_prefix,"ai21labs/Jamba-tiny-dev"), + os.path.join(models_path_prefix,"yujiepan/mamba2-codestral-v0.1-tiny-random"), + os.path.join(models_path_prefix,"Zyphra/Zamba2-1.2B-instruct"), + os.path.join(models_path_prefix,"hmellor/tiny-random-BambaForCausalLM"), + os.path.join(models_path_prefix,"ibm-granite/granite-4.0-tiny-preview"), + os.path.join(models_path_prefix,"tiiuae/Falcon-H1-0.5B-Base"), + os.path.join(models_path_prefix,"LiquidAI/LFM2-1.2B"), +] + +FULL_CUDA_GRAPH_MODELS = [ + os.path.join(models_path_prefix,"ai21labs/Jamba-tiny-dev"), + os.path.join(models_path_prefix,"Zyphra/Zamba2-1.2B-instruct"), ] +V0_UNSUPPORTED_MODELS = [ + os.path.join(models_path_prefix,"LiquidAI/LFM2-1.2B"), +] # Avoid OOM MAX_NUM_SEQS = 4 -# Once we add support for FCG in Mamba1, this list will be removed and tests -# all test cases will use enforce_eager=False -ENFORCE_EAGER_MODELS_V1 = [ - os.path.join(models_path_prefix, "state-spaces/mamba-130m-hf"), - os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"), -] - @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -99,31 +102,23 @@ def test_models( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None if model in V1_SUPPORTED_MODELS: - enforce_eager = False - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - - if model in ENFORCE_EAGER_MODELS_V1: - enforce_eager = True - - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - enforce_eager=enforce_eager, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) else: vllm_v1_outputs = None - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -133,6 +128,7 @@ def test_models( if model in V1_SUPPORTED_MODELS: ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, @@ -141,7 +137,7 @@ def test_models( ) -@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_batching( @@ -151,7 +147,6 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -189,29 +184,32 @@ def test_chunked_prefill( max_tokens: int, num_logprobs: int, chunked_prefill_token_size: int, + monkeypatch, ) -> None: max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, num_logprobs) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: - non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, + enable_chunked_prefill=False, + max_num_seqs=max_num_seqs) as vllm_model: + non_chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) + check_logprobs_close( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -282,25 +280,29 @@ def test_models_preemption_recompute( example_prompts, model: str, max_tokens: int, + monkeypatch, ) -> None: """ Tests that outputs are identical with and w/o preemptions (recompute). """ - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.llm.llm_engine.scheduler[0] - scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - scheduler.ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + scheduler = vllm_model.llm.llm_engine.scheduler[0] + scheduler.ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + scheduler.ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -376,7 +378,7 @@ def test_distributed_correctness( ) -@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_full_cuda_graph( @@ -403,23 +405,20 @@ def test_full_cuda_graph( else: hf_outputs = None + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - compilation_config={'full_cuda_graph': True}, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - if hf_outputs is not None: + if hf_outputs is not None and vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -428,6 +427,7 @@ def test_full_cuda_graph( ) ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + assert ref_outputs is not None check_logprobs_close( outputs_0_lst=ref_outputs, outputs_1_lst=vllm_v1_outputs, @@ -463,24 +463,20 @@ def test_fp32_state( else: hf_outputs = None - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32") as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + m.setenv("VLLM_USE_V1", "0") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32", - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + if hf_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f8a4bcfd45b41d537f9c602fbd24234..a74ad2aa2597247e1f3063c9ba9ea00ae80326a4 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 4a1f8a53d024c1b00f65afc984a59fc1009f65de..640858125bfcaccb9050815c2168bda4fa28f54a 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 206524d7caad3836b26e701da72bdf4bb4f9b7b5..f473e0ba01ffa2c64403c60eb4ba5573bf6ac0f0 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + architecture="GemmaForSequenceClassification", + hf_overrides={ + "architectures": + ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": + "no_post_processing", + }), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 @@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder): @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") - - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 1b576fa33fb8ab767c43d7afce1bb61ac36e0501..ceeb173777e53fcd040a4ed8ed2c317a4d276ca0 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -13,14 +13,6 @@ from vllm.platforms import current_platform from ...utils import check_embeddings_close, check_transformers_version -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.mark.parametrize( "model", [ @@ -35,21 +27,15 @@ def v1(run_with_both_engines): os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct"), # CPU v1 doesn't support sliding window marks=[pytest.mark.core_model]), - # the qwen models interfere with each other (see PR - # https://github.com/vllm-project/vllm/pull/18720). - # To avoid this problem, for now we skip v0 since it will be - # deprecated anyway. pytest.param(os.path.join(models_path_prefix, "ssmits/Qwen2-7B-Instruct-embed-base"), - marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), + marks=[pytest.mark.cpu_model]), # [Encoder-only] pytest.param(os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5"), marks=[pytest.mark.core_model]), pytest.param(os.path.join(models_path_prefix, "sentence-transformers/all-MiniLM-L12-v2")), pytest.param(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")), - pytest.param(os.path.join(models_path_prefix, "Alibaba-NLP/gte-Qwen2-1.5B-instruct"), - marks=[pytest.mark.skip_v1]), + pytest.param(os.path.join(models_path_prefix, "Alibaba-NLP/gte-Qwen2-1.5B-instruct")), # [Cross-Encoder] - pytest.param(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2"), - marks=[pytest.mark.skip_v1]), + pytest.param(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2")), ], ) diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 6c51a10a26025c9f377a6722ed1eb256199a77ab..8248fddf93ac6ff1895136acd319c0cf54b2a33e 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -18,6 +18,7 @@ from ....utils import models_path_prefix MODEL_NAME = os.path.join(models_path_prefix, "parasail-ai/GritLM-7B-vllm") MAX_MODEL_LEN = 4000 +ATOL = 0.002 def _arr(arr): @@ -101,16 +102,16 @@ def get_test_data(): def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) - assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001) + assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=ATOL) cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) - assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001) + assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=ATOL) cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) - assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001) + assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=ATOL) cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) - assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001) + assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=ATOL) def test_gritlm_offline_embedding(vllm_runner): diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index f805a64103c06966f3e9fbd78aaef7577c8a02aa..9911620c018ef82fbabc8bb5e7dc1f71233bb717 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import pytest @@ -33,12 +32,15 @@ MODELS = [ ########### NewModel CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), ########### Qwen2ForCausalLM LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -60,11 +62,16 @@ MODELS = [ ] RERANK_MODELS = [ - # classifier_pooling: mean CLSPoolingRerankModelInfo( + # classifier_pooling: mean "Alibaba-NLP/gte-reranker-modernbert-base", architecture="ModernBertForSequenceClassification", enable_test=True), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True), ] @@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) @@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py new file mode 100644 index 0000000000000000000000000000000000000000..45366f20941447fbc64b82e83d2645b166ac002b --- /dev/null +++ b/tests/models/language/pooling/test_multilabel_classification_support.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + + +@pytest.mark.parametrize( + "model", + ["Rami/multi-label-class-classification-on-github-issues"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index 480bd5e4567cbd8d51ed33c84865e78bd2855859..73823deeff4e0a6017a134eafb85174037932317 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -10,12 +10,20 @@ from tests.conftest import HfRunner from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=False) ] @@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - "classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 37f5566a330d0dc0fbcbd625ebc0dab54e848fce..5dd2d9eae911567f4831f48f137e6359aa413f62 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=False) ] @@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @@ -96,16 +93,8 @@ def test_rerank_models_mteb_tp(vllm_runner, assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, "tensor_parallel_size": 2, } - mteb_test_rerank_models(Qwen3RerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - atol=1.2e-2) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index beafa0aed9862ff44ee78c5e08abd29f3ad41742..08722ac98b7ed256ddb3fad5a5ef869c24530ad9 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -13,14 +13,6 @@ from ....conftest import HfRunner from ...utils import check_transformers_version -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture def math_step_prompts(): # ruff: noqa: E501 diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index 488e6499b1913b536d79ec0a6ed739da34d90ac9..d62e8a8749b7b3ea631d299f3e822d16d9061292 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -27,15 +27,6 @@ TEXTS_2 = [ "The capital of Germany is Berlin.", ] - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - DTYPE = "half" diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..51ddbcc5ab249646e34955d710684f17df76dfaa --- /dev/null +++ b/tests/models/language/pooling/test_st_projector.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .mteb_utils import mteb_test_embed_models + +# ST models with projector (Dense) layers +ST_PROJECTOR_MODELS = [ + CLSPoolingEmbedModelInfo( + "TencentBAC/Conan-embedding-v1", + architecture="BertModel", + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + + mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 84e8836b33ad5e81c290d9df73a2fe488662a5ba..8ba5f358003c972211b63162060cef98eb91885d 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -192,23 +192,21 @@ VLM_TEST_SETTINGS = { }, marks=[pytest.mark.core_model], ), - # FIXME(Isotr0py): Enable this test after - # https://github.com/huggingface/transformers/pull/39470 released - # "idefics3-transformers": VLMTestInfo( - # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], - # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 - # img_idx_to_prompt=lambda idx: "<image>", - # max_model_len=8192, - # max_num_seqs=2, - # auto_cls=AutoModelForImageTextToText, - # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, - # image_size_factors=[(0.25, 0.5, 1.0)], - # vllm_runner_kwargs={ - # "model_impl": "transformers", - # }, - # marks=[pytest.mark.core_model], - # ), + "idefics3-transformers": VLMTestInfo( + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], @@ -225,21 +223,6 @@ VLM_TEST_SETTINGS = { }, marks=[large_gpu_mark(min_gb=32)], ), - # Check "auto" with fallback to transformers - "internvl-transformers": VLMTestInfo( - models=["OpenGVLab/InternVL3-1B-hf"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", - max_model_len=4096, - use_tokenizer_eos=True, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "auto", - }, - auto_cls=AutoModelForImageTextToText, - marks=[pytest.mark.core_model], - ), #### Extended model tests "aria": VLMTestInfo( models=[os.path.join(models_path_prefix, "rhymes-ai/Aria")], @@ -340,10 +323,6 @@ VLM_TEST_SETTINGS = { vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], - # FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we - # should enable this again after the fix is released: - # https://github.com/huggingface/transformers/pull/39915 - marks=[pytest.mark.skip("HF model is broken")], ), "gemma3": VLMTestInfo( models=[os.path.join(models_path_prefix, "google/gemma-3-4b-it")], @@ -464,6 +443,20 @@ VLM_TEST_SETTINGS = { use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-hf": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO, + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", + video_idx_to_prompt=lambda idx: "<video>", + max_model_len=8192, + use_tokenizer_eos=True, + auto_cls=AutoModelForImageTextToText, + ), "kimi_vl": VLMTestInfo( models=["moonshotai/Kimi-VL-A3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -624,6 +617,23 @@ VLM_TEST_SETTINGS = { hf_model_kwargs={"llm_attn_implementation": "sdpa"}, patch_hf_runner=model_utils.ovis_patch_hf_runner, ), + "ovis2_5": VLMTestInfo( + models=["AIDC-AI/Ovis2.5-2B"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + video_idx_to_prompt=lambda idx: "<video>\n", + max_model_len=4096, + max_num_seqs=2, + dtype="half", + num_logprobs=10, + patch_hf_runner=model_utils.ovis2_5_patch_hf_runner, + hf_model_kwargs={"revision": "refs/pr/5"}, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index d30ff2c3405c7c36bb1f3435fdd001c10751645a..b1b24c40a05e2ce22b0bc0c0531fe2c800565eb5 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -6,6 +6,7 @@ from typing import Optional, overload import os import pytest import torch +from packaging.version import Version from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer from transformers import __version__ as TRANSFORMERS_VERSION @@ -289,8 +290,8 @@ def clear_cache(): @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, model, sizes, dtype, max_tokens, @@ -321,8 +322,8 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, @@ -374,8 +375,8 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, @@ -418,8 +419,8 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.skipif( - TRANSFORMERS_VERSION == "4.55.0", - reason="Transformers v4.55.0 has a regression issue on mllama, " + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " "see: https://github.com/huggingface/transformers/pull/40083") def test_models_distributed( hf_runner, diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index c53243b42e384a9e65eb2c9dadcad7527460c258..e369416fc49ccda5a7d5c09bd6cab3f9b2ba372a 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -1,12 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" -from io import BytesIO from typing import Callable -import requests -from PIL import Image - +from vllm.assets.image import ImageAsset from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) @@ -118,9 +115,9 @@ def different_patch_input_cases_internvl(): def windows_attention_image_qwen2_5_vl(): - # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 - image_url = "https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg" - image = Image.open(BytesIO(requests.get(image_url).content)) + + # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 # noqa: E501 + image = ImageAsset("hato").pil_image question = "Describe the image." img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 5e8dac6bce96a245edc2007a1dd5bd356efff0b1..8b7d051218f1492b8174c570a67719d374ed2700 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -10,6 +10,7 @@ from typing import Optional, Union import numpy as np import numpy.typing as npt +import PIL.Image import pytest import regex as re import torch @@ -19,7 +20,6 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs -from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -343,7 +343,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" hf_processor = hf_model.processor - patch_padding_side(hf_processor) def processor(*args, text="", images=None, **kwargs): if images is None: @@ -812,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for Ovis2.""" + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.llm.get_output_embeddings() + + def processor(*args, text="", images=None, videos=None, **kwargs): + if images is None: + images = [] + else: + images = [images] if isinstance(images, Image) else images + if videos is None: + videos = [] + else: + videos = [videos] if isinstance(videos, np.ndarray) else videos + videos = [[PIL.Image.fromarray(frame) for frame in vid] + for vid in videos] + + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = text.split(start)[1].split(end)[0] + break + + images_message = [{"type": "image", "image": img} for img in images] + videos_message = [{"type": "video", "video": vid} for vid in videos] + + messages = [{ + "role": + "user", + "content": [ + *images_message, + *videos_message, + { + "type": "text", + "text": text + }, + ], + }] + + input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs( + messages=messages, enable_thinking=True) + inputs = { + "inputs": input_ids, + "pixel_values": pixel_values, + "grid_thws": grid_thws, + } + return BatchFeature(data=inputs, tensor_type="pt") + + hf_model.processor = processor + return hf_model + + def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner for Qwen2.5-Omni.""" thinker = hf_model.model.thinker diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 24b4aa4bc7679ebce10bb2f818ee9174ab4f6568..bf9a83bb82b673140f4a93ad107823399bd50f00 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -15,8 +15,9 @@ from PIL import Image from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) @@ -65,6 +66,8 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -73,8 +76,7 @@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() @@ -104,7 +106,7 @@ def _test_processing_correctness( partial(random_video, rng, min_frames=2, - max_frames=8, + max_frames=16, min_wh=128, max_wh=256), "audio": @@ -162,8 +164,10 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, + "ovis2_5": False, "paligemma": False, "ultravox": False, "whisper": False, @@ -265,60 +269,72 @@ def _test_processing_correctness_one( # yapf: disable @pytest.mark.parametrize("model_id", [ - os.path.join(models_path_prefix, "rhymes-ai/Aria"), - os.path.join(models_path_prefix, "CohereForAI/aya-vision-8b"), - os.path.join(models_path_prefix, "Salesforce/blip2-opt-2.7b"), - os.path.join(models_path_prefix, "facebook/chameleon-7b"), - os.path.join(models_path_prefix, "deepseek-ai/deepseek-vl2-tiny"), - os.path.join(models_path_prefix, "microsoft/Florence-2-base"), - os.path.join(models_path_prefix, "adept/fuyu-8b"), - os.path.join(models_path_prefix, "google/gemma-3-4b-it"), - os.path.join(models_path_prefix, "google/gemma-3n-E2B-it"), - os.path.join(models_path_prefix, "zai-org/glm-4v-9b"), - os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking"), - os.path.join(models_path_prefix, "ibm-granite/granite-speech-3.3-2b"), - os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"), - os.path.join(models_path_prefix, "internlm/Intern-S1"), - os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"), - os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B"), - os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"), - os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"), - os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"), - os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"), - os.path.join(models_path_prefix, "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"), - os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf"), - os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf"), - os.path.join(models_path_prefix, "llava-hf/LLaVA-NeXT-Video-7B-hf"), - os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), - os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct"), - os.path.join(models_path_prefix, "TIGER-Lab/Mantis-8B-siglip-llama3"), - os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"), - os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"), - os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"), - os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"), - os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"), - os.path.join(models_path_prefix, "allenai/Molmo-7B-O-0924"), - os.path.join(models_path_prefix, "nvidia/NVLM-D-72B"), - os.path.join(models_path_prefix, "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"), - os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B"), - os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"), - os.path.join(models_path_prefix, "AIDC-AI/Ovis2-1B"), - os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"), - os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448"), - os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"), - os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct"), - os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"), - os.path.join(models_path_prefix, "mistral-community/pixtral-12b"), - os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat"), - os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct"), - os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct"), - os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct"), - os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B"), - os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"), - os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"), - os.path.join(models_path_prefix, "openai/whisper-large-v3"), - os.path.join(models_path_prefix, "omni-research/Tarsier-7b"), - os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"), + os.path.join(models_path_prefix,"rhymes-ai/Aria"), + os.path.join(models_path_prefix,"CohereForAI/aya-vision-8b"), + os.path.join(models_path_prefix,"Salesforce/blip2-opt-2.7b"), + os.path.join(models_path_prefix,"facebook/chameleon-7b"), + os.path.join(models_path_prefix,"CohereLabs/command-a-vision-07-2025"), + os.path.join(models_path_prefix,"deepseek-ai/deepseek-vl2-tiny"), + os.path.join(models_path_prefix,"naver-clova-ix/donut-base-finetuned-docvqa"), + os.path.join(models_path_prefix,"baidu/ERNIE-4.5-VL-28B-A3B-PT"), + os.path.join(models_path_prefix,"microsoft/Florence-2-base"), + os.path.join(models_path_prefix,"adept/fuyu-8b"), + os.path.join(models_path_prefix,"google/gemma-3-4b-it"), + os.path.join(models_path_prefix,"google/gemma-3n-E2B-it"), + os.path.join(models_path_prefix,"zai-org/glm-4v-9b"), + os.path.join(models_path_prefix,"zai-org/GLM-4.1V-9B-Thinking"), + os.path.join(models_path_prefix,"zai-org/GLM-4.5V"), + os.path.join(models_path_prefix,"ibm-granite/granite-speech-3.3-2b"), + os.path.join(models_path_prefix,"h2oai/h2ovl-mississippi-800m"), + os.path.join(models_path_prefix,"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"), + os.path.join(models_path_prefix,"HuggingFaceM4/Idefics3-8B-Llama3"), + os.path.join(models_path_prefix,"internlm/Intern-S1)", + os.path.join(models_path_prefix,"OpenGVLab/InternVL2-1B"), + os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B)", + os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-1B"), + os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"), + os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-30B-A3B"), + os.path.join(models_path_prefix,"Kwai-Keye/Keye-VL-8B-Preview"), + os.path.join(models_path_prefix,"moonshotai/Kimi-VL-A3B-Instruct"), + os.path.join(models_path_prefix,"meta-llama/Llama-4-Scout-17B-16E-Instruct"), + os.path.join(models_path_prefix,"llava-hf/llava-1.5-7b-hf"), + os.path.join(models_path_prefix,"llava-hf/llava-v1.6-mistral-7b-hf"), + os.path.join(models_path_prefix,"llava-hf/LLaVA-NeXT-Video-7B-hf"), + os.path.join(models_path_prefix,"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), + os.path.join(models_path_prefix,"meta-llama/Llama-3.2-11B-Vision-Instruct"), + os.path.join(models_path_prefix,"TIGER-Lab/Mantis-8B-siglip-llama3"), + os.path.join(models_path_prefix,"openbmb/MiniCPM-Llama3-V-2_5"), + os.path.join(models_path_prefix,"openbmb/MiniCPM-o-2_6"), + os.path.join(models_path_prefix,"openbmb/MiniCPM-V-2_6"), + os.path.join(models_path_prefix,"MiniMaxAI/MiniMax-VL-01"), + os.path.join(models_path_prefix,"allenai/Molmo-7B-D-0924"), + os.path.join(models_path_prefix,"allenai/Molmo-7B-O-0924"), + os.path.join(models_path_prefix,"nvidia/NVLM-D-72B"), + os.path.join(models_path_prefix,"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"), + os.path.join(models_path_prefix,"AIDC-AI/Ovis1.6-Gemma2-9B"), + os.path.join(models_path_prefix,"AIDC-AI/Ovis1.6-Llama3.2-3B"), + os.path.join(models_path_prefix,"AIDC-AI/Ovis2-1B"), + os.path.join(models_path_prefix,"AIDC-AI/Ovis2.5-2B"), + os.path.join(models_path_prefix,"google/paligemma-3b-mix-224"), + os.path.join(models_path_prefix,"google/paligemma2-3b-ft-docci-448"), + os.path.join(models_path_prefix,"microsoft/Phi-3.5-vision-instruct"), + os.path.join(models_path_prefix,"microsoft/Phi-4-multimodal-instruct"), + os.path.join(models_path_prefix,"mistralai/Pixtral-12B-2409"), + os.path.join(models_path_prefix,"mistral-community/pixtral-12b"), + os.path.join(models_path_prefix,"Qwen/Qwen-VL-Chat"), + os.path.join(models_path_prefix,"Qwen/Qwen2-VL-2B-Instruct"), + os.path.join(models_path_prefix,"Qwen/Qwen2.5-VL-3B-Instruct"), + os.path.join(models_path_prefix,"Qwen/Qwen2-Audio-7B-Instruct"), + os.path.join(models_path_prefix,"Qwen/Qwen2.5-Omni-3B"), + os.path.join(models_path_prefix,"YannQi/R-4B"), + os.path.join(models_path_prefix,"Skywork/Skywork-R1V-38B"), + os.path.join(models_path_prefix,"HuggingFaceTB/SmolVLM2-2.2B-Instruct"), + os.path.join(models_path_prefix,"stepfun-ai/step3"), + os.path.join(models_path_prefix,"fixie-ai/ultravox-v0_5-llama-3_2-1b"), + os.path.join(models_path_prefix,"openai/whisper-large-v3"), + os.path.join(models_path_prefix,"omni-research/Tarsier-7b"), + os.path.join(models_path_prefix,"omni-research/Tarsier2-Recap-7b"), + os.path.join(models_path_prefix,"mistralai/Voxtral-Mini-3B-2507"), ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @@ -372,10 +388,16 @@ def _assert_inputs_equal( if ignore_mm_keys is None: ignore_mm_keys = set() - assert "mm_kwargs" in a and "mm_kwargs" in b, msg + a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"} + b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"} + + assert a_rest == b_rest, msg + + a_data = a["mm_kwargs"].get_data() + b_data = b["mm_kwargs"].get_data() for key in ignore_mm_keys: - a["mm_kwargs"].pop(key, None) - b["mm_kwargs"].pop(key, None) + a_data.pop(key, None) + b_data.pop(key, None) - assert a == b, msg + assert a_data == b_data, msg diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a6d900ec5d895904702634591cef49a94a7a32e8..a49842e1099c2f80e0756d7208d907bd17b16225 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -45,7 +45,8 @@ def test_processor_override( video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) video_tok_count = processed_inputs["prompt_token_ids"].count( video_token_id) - grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0] + grid_t, _, _ = processed_inputs["mm_kwargs"].get_data( + )["video_grid_thw"][0] assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 76e4acc67d4d5f14275388e10b4b47a705997177..1adfe21352c41be920d163830c34150b0f0ff988 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -108,7 +108,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index c662714dbdaa7666dc57cc366ebbdb88a9c53a6a..b938a990c9d7d9870d2dca1c33fea0eeb34b98de 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -70,7 +70,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index dce313e42d67f060cf2e58f26341f95eeea3eb28..69ebce784ce8769afb445a38f4b26ed7988d0ce4 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -53,14 +53,14 @@ def test_processor_override( prompt = encode_tokens(tokenizer, prompt) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) - mm_kwargs = processed_inputs["mm_kwargs"] + mm_data = processed_inputs["mm_kwargs"].get_data() # place holder replacements prompt_token_ids = processed_inputs["prompt_token_ids"] assert prompt_token_ids.count(config.boi_token_index) == num_imgs assert prompt_token_ids.count(config.eoi_token_index) == num_imgs assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs - aspect_ratios = mm_kwargs["aspect_ratios"] + aspect_ratios = mm_data["aspect_ratios"] num_x_separators = num_y_separators = 0 for tiles_y, tiles_x in aspect_ratios: if tiles_x * tiles_y > 1: @@ -82,6 +82,6 @@ def test_processor_override( num_patches_per_chunk = processor.info.get_patch_per_chunk( config.vision_config) assert prompt_token_ids.count(config.image_token_index) \ - == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk - assert mm_kwargs["pixel_values"].shape[0] \ - == mm_kwargs["patches_per_image"].sum() \ No newline at end of file + == sum(mm_data["patches_per_image"]) * num_patches_per_chunk + assert len(mm_data["pixel_values"]) \ + == sum(mm_data["patches_per_image"]) diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py index 5f428ccafc16edfeec47db765158d10d4399735e..e491619fccff98bdf5ad30a435ffe1cd3639e9fc 100644 --- a/tests/models/multimodal/processing/test_mllama.py +++ b/tests/models/multimodal/processing/test_mllama.py @@ -51,18 +51,18 @@ def test_profiling( encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) ] * max_num_seqs - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() # Get the actual number of encoder tokens for each sample. # Because attn_metadata.encoder_seq_lens only counts the last # group of images for each sample, which is used to cheat the # block manager to allocate blocks for those images only. # See MllamaMultiModalProcessor for more details. - num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")] + num_tiles = [[t] for t in mm_data.pop("num_tiles")] num_tokens_per_tile = calc_token_per_chunk(image_size) actual_encoder_seq_lens = [ sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index f3871b60c3f645a20e01a0114109a978e7481cdd..3be77b5da63f22b33a30ec5aee448202151dc814 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int): hf_config = ctx.get_hf_config(Llama4Config) - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() image_size = hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size downsample_ratio = int( round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio - chunks_per_image = prod(mm_kwargs["patches_per_image"]) + chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch - num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][ - 0][1] # x-y seperator tokens + num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ + 1] # x-y seperator tokens total_tokens = total_num_patches.item() + num_tiles.item( ) + 3 # image start, image, image end diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index 6fbbab0d26124898264180531d95b62d52ad74e4..d9f1965a053df68b4918c7c2cbf14acbab86e4c4 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -70,7 +70,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<image>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index e8464fb8618747616981bbf737c05b8cbee17a09..30f8953114a9231914f40d7886d0474e4ae9c44b 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -50,7 +50,8 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values"].shape assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py similarity index 63% rename from tests/models/multimodal/test_tensor_schema.py rename to tests/models/multimodal/processing/test_tensor_schema.py index 036624431c20b85e9c81c3b8e786404e89d4cb84..1a11fa3d2b824d69f8a874ca6df85afaaa86aa55 100644 --- a/tests/models/multimodal/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -1,36 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile from collections.abc import Iterable +from contextlib import contextmanager from functools import partial from typing import Any, Union -from unittest.mock import patch import numpy as np import pytest +import torch.nn as nn from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, UserMessage) from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image -from vllm.config import ModelConfig -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine +from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import (cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel) from vllm.inputs import InputProcessingContext -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads -from vllm.v1.core.kv_cache_utils import get_kv_cache_config -from vllm.v1.engine.core import EngineCore as V1EngineCore +from vllm.utils import is_list_of -from ...conftest import VllmRunner -from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS -from ..utils import dummy_hf_overrides +from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ...utils import dummy_hf_overrides ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", - "MiniMaxVL01ForConditionalGeneration": "broken model", } ARCH_NEEDS_EXTRAS = [ "InternVLChatModel", @@ -39,7 +39,12 @@ ARCH_NEEDS_EXTRAS = [ "MiniCPMV", "PaliGemmaForConditionalGeneration", ] -REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"} +REPO_ID_TO_SKIP = { + "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", + # FIXME(Isotr0py): enable GPT-OSS based InternVL3.5 model + # after support PP for GPT-OSS + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview": "Broken model", +} ImageInput = list[Image.Image] VideoInput = Union[list[Image.Image], list[np.ndarray], @@ -128,11 +133,32 @@ def create_batched_mm_kwargs( )["mm_kwargs"] items = [ item for modality in supported_mm_limits - for item in mm_kwargs.get_items(modality) + for item in mm_kwargs[modality] ] return group_mm_kwargs_by_modality(items) +@contextmanager +def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(tensor_model_parallel_size=1) + vllm_config = VllmConfig(model_config=model_config) + with set_current_vllm_config(vllm_config=vllm_config): + with set_default_torch_dtype(model_config.dtype): + model = model_cls(vllm_config=vllm_config) + yield model + + del model + cleanup_dist_env_and_memory() + + def get_model_id_to_test( model_arch_list: Iterable[str]) -> list[tuple[str, str]]: filtered_results = [] @@ -148,12 +174,10 @@ def get_model_id_to_test( return filtered_results -@pytest.mark.core_model @pytest.mark.parametrize( "model_arch, model_id", get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) -def test_model_tensor_schema(model_arch: str, model_id: str, - vllm_runner: type[VllmRunner], monkeypatch): +def test_model_tensor_schema(model_arch: str, model_id: str): if model_arch in ARCH_TO_SKIP: pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") if model_id in REPO_ID_TO_SKIP: @@ -174,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, + hf_overrides=hf_overrides_fn, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] - if not any( - hasattr(model_cls, f"_parse_and_validate_{m}_input") - for m in ["image", "video", "audio"]): + inputs_parse_methods = [] + for attr_name in dir(model_cls): + attr = getattr(model_cls, attr_name) + if hasattr(attr, "__annotations__"): + return_type = attr.__annotations__.get("return", None) + if return_type is not None and "Input" in str(return_type): + inputs_parse_methods.append(attr_name) + + if not any(inputs_parse_methods): pytest.skip(f"{model_arch} does not support tensor schema validation.") ctx = InputProcessingContext( @@ -194,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str, modality: 3 if limit is None else limit for modality, limit in supported_mm_limits.items() } - - # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - - def _initialize_kv_caches_v1(self, vllm_config): - kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( - vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) - - # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config - return 1, 0, scheduler_kv_cache_config - - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): - m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") - - # TODO(Isotr0py): Can we avoid initializing engine? - with ( - set_default_torch_num_threads(1), - vllm_runner( - model_id, - tokenizer_name=model_info.tokenizer, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - max_model_len=model_info.max_model_len, - load_format="dummy", - hf_overrides=hf_overrides_fn, - limit_mm_per_prompt=limit_mm_per_prompt, - enforce_eager=True, - ) as vllm_model, - ): - model_config = vllm_model.llm.llm_engine.model_config - llm_engine = vllm_model.llm.llm_engine - - if hasattr(llm_engine, "processor"): - # v1 processor - mm_registry = llm_engine.processor.mm_registry - else: - # v0 input_preprocessor - mm_registry = llm_engine.input_preprocessor.mm_registry - - processor = mm_registry.create_processor(model_config) - - def validate_model_input(model, modality: str, - mm_kwargs: MultiModalKwargs): - method_name = f"_parse_and_validate_{modality}_input" - if hasattr(model, method_name): - getattr(model, method_name)(**mm_kwargs) - - for modality, _, mm_kwargs in create_batched_mm_kwargs( - model_config, processor): - valid_func = partial(validate_model_input, - modality=modality, - mm_kwargs=mm_kwargs) - vllm_model.apply_model(valid_func) + model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt + processor = factories.build_processor(ctx, cache=None) + + with initialize_dummy_model(model_cls, model_config) as model: + for modality, _, mm_kwargs in create_batched_mm_kwargs( + model_config, processor): + for method_name in inputs_parse_methods: + print(f"Testing `{method_name}` with modality={modality} " + f"and mm_kwargs{list(mm_kwargs.keys())}") + getattr(model, method_name)(modality=modality, **mm_kwargs) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index f5648b0f6335f786f8fdb07821bab9cb7c4d8f75..2c855fbaf6cabed7fa831706be52194f0320c0f9 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -33,7 +33,7 @@ from ...utils import models_path_prefix # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) @@ -58,9 +58,6 @@ def test_models( numerical sensitive kernels. """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): pytest.skip( f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") diff --git a/tests/models/registry.py b/tests/models/registry.py index b761cac8abf42c9a9c853e365d410dbc93b83d8e..0264d0f2ddf25ee2e4cbff39eeb4a0a9841f93f6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -141,7 +141,10 @@ class _HfExamplesInfo: # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "AquilaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/AquilaChat-7B"), + "ApertusForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"swiss-ai/Apertus-8B"), + min_transformers_version="4.56.0", + trust_remote_code=True), + "AquilaModel": _HfExamplesInfo(os.path.join(models_path_prefix,"BAAI/AquilaChat-7B"), trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/AquilaChat2-7B"), trust_remote_code=True), @@ -219,10 +222,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "HunYuanDenseV1ForCausalLM":_HfExamplesInfo(os.path.join(models_path_prefix, "tencent/Hunyuan-7B-Instruct-0124"), trust_remote_code=True, is_available_online=False), - "HCXVisionForCausalLM": _HfExamplesInfo( - os.path.join(models_path_prefix, "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"), - trust_remote_code=True), - "InternLMForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/internlm-chat-7b"), + "InternLMForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"internlm/internlm-chat-7b"), trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/internlm2-chat-7b"), trust_remote_code=True), @@ -237,11 +237,13 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "tiny": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"), "random": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-random"), # noqa: E501 }), - "LlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"), - extras={"guard": os.path.join(models_path_prefix, "meta-llama/Llama-Guard-3-1B"), # noqa: E501 - "hermes": os.path.join(models_path_prefix, "NousResearch/Hermes-3-Llama-3.1-8B"), # noqa: E501 - "fp8": os.path.join(models_path_prefix, "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8")}), # noqa: E501 - "LLaMAForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "decapoda-research/llama-7b-hf"), + "Lfm2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"LiquidAI/LFM2-1.2B"), + min_transformers_version="4.54"), + "LlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"meta-llama/Llama-3.2-1B-Instruct"), + extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 + "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 + "LLaMAForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"decapoda-research/llama-7b-hf"), is_available_online=False), "Llama4ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 is_available_online=False), @@ -293,18 +295,20 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "Qwen2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen2-0.5B-Instruct"), extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 - "Qwen2MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen1.5-MoE-A2.7B-Chat")), - "Qwen3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-8B")), - "Qwen3MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-30B-A3B")), - "RWForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "tiiuae/falcon-40b")), - "SmolLM3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceTB/SmolLM3-3B")), - "StableLMEpochForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stabilityai/stablelm-zephyr-3b")), # noqa: E501 - "StableLmForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stabilityai/stablelm-3b-4e1t")), - "Starcoder2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "bigcode/starcoder2-3b")), - "Step3TextForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"), - trust_remote_code=True, - is_available_online=False), - "SolarForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "upstage/solar-pro-preview-instruct"), + "Qwen2MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen1.5-MoE-A2.7B-Chat")), + "Qwen3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen3-8B")), + "Qwen3MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"Qwen/Qwen3-30B-A3B")), + "RWForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"tiiuae/falcon-40b")), + "SeedOssForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"ByteDance-Seed/Seed-OSS-36B-Instruct"), # noqa: E501 + trust_remote_code=True, + is_available_online=False), + "SmolLM3ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"HuggingFaceTB/SmolLM3-3B")), + "StableLMEpochForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-zephyr-3b")), # noqa: E501 + "StableLmForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stabilityai/stablelm-3b-4e1t")), + "Starcoder2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"bigcode/starcoder2-3b")), + "Step3TextForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"stepfun-ai/step3"), + trust_remote_code=True), + "SolarForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"upstage/solar-pro-preview-instruct"), trust_remote_code=True), "TeleChat2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Tele-AI/TeleChat2-3B"), trust_remote_code=True), @@ -330,8 +334,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5"), v0_only=True), - "Gemma2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2"), v0_only=True), # noqa: E501 + "BertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5")), + "Gemma2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2")), # noqa: E501 "GritLM": _HfExamplesInfo(os.path.join(models_path_prefix, "parasail-ai/GritLM-7B-vllm")), "GteModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Snowflake/snowflake-arctic-embed-m-v2.0"), trust_remote_code=True), @@ -344,9 +348,9 @@ _EMBEDDING_EXAMPLE_MODELS = { "LlamaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "llama"), is_available_online=False), "MistralModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct")), "ModernBertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-modernbert-base"), - trust_remote_code=True, v0_only=True), + trust_remote_code=True), "NomicBertModel": _HfExamplesInfo(os.path.join(models_path_prefix, "nomic-ai/nomic-embed-text-v2-moe"), - trust_remote_code=True, v0_only=True), # noqa: E501 + trust_remote_code=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo(os.path.join(models_path_prefix, "ssmits/Qwen2-7B-Instruct-embed-base")), "Qwen2ForRewardModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"), max_transformers_version="4.53", @@ -354,9 +358,9 @@ _EMBEDDING_EXAMPLE_MODELS = { "Qwen2ForProcessRewardModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-PRM-7B"), max_transformers_version="4.53", transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - "RobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2"), v0_only=True), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/all-roberta-large-v1"), v0_only=True), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small"), v0_only=True), # noqa: E501 + "RobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2")), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo(os.path.join(models_path_prefix, "sentence-transformers/all-roberta-large-v1")), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo(os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")), # noqa: E501 # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "royokong/e5-v")), "Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "TIGER-Lab/VLM2Vec-Full"), @@ -371,16 +375,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { "GPT2ForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "nie3e/sentiment-polish-gpt2-small")), # noqa: E501 # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2"), v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-reranker-modernbert-base"), v0_only=True), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/quora-roberta-base"), v0_only=True), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3"), v0_only=True), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2")), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-multilingual-reranker-base"), # noqa: E501 + trust_remote_code=True, + hf_overrides={ + "architectures": [os.path.join(models_path_prefix, "GteNewForSequenceClassification")]}),# noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "Alibaba-NLP/gte-reranker-modernbert-base")), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "cross-encoder/quora-roberta-base")), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3")), # noqa: E501 } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion "GemmaForSequenceClassification": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-gemma"), # noqa: E501 - v0_only=True, hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 @@ -403,6 +410,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-VL-28B-A3B-PT"), # noqa: E501 + trust_remote_code=True), "FuyuForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "adept/fuyu-8b")), "Gemma3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/gemma-3-4b-it")), "Gemma3nForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/gemma-3n-E2B-it"), # noqa: E501 @@ -413,23 +422,29 @@ _MULTIMODAL_EXAMPLE_MODELS = { hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")), # noqa: E501 "Glm4vMoeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.5V"), - is_available_online=False), # noqa: E501 + min_transformers_version="4.56"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"), trust_remote_code=True, extras={"2b": os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-2b")}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible."), # noqa: E501 - "Idefics3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501 - {"tiny": os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM-256M-Instruct")}, # noqa: E501 - min_transformers_version="4.55.1", - transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 - "InternVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"), - extras={"2B": os.path.join(models_path_prefix, "OpenGVLab/InternVL2-2B"), - "3.0": os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B")}, # noqa: E501 - trust_remote_code=True), - "InternS1ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "internlm/Intern-S1"), + "HCXVisionForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix,"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"), # noqa: E501 + trust_remote_code=True), + "Idefics3ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501 + {"tiny": os.path.join(models_path_prefix,"HuggingFaceTB/SmolVLM-256M-Instruct")}, # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55"), # noqa: E501 + "InternS1ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"internlm/Intern-S1"), + trust_remote_code=True), # noqa: E501 + "InternVLChatModel": _HfExamplesInfo(os.path.join(models_path_prefix,"OpenGVLab/InternVL2-1B"), + extras={"2B": os.path.join(models_path_prefix,"OpenGVLab/InternVL2-2B"), + "3.0": os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B"), # noqa: E501 + "3.5-qwen3": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-1B"), # noqa: E501 + "3.5-qwen3moe": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-30B-A3B"), # noqa: E501 + "3.5-gptoss": os.path.join(models_path_prefix,"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview")}, # noqa: E501 trust_remote_code=True), - "KeyeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "Kwai-Keye/Keye-VL-8B-Preview"), # noqa: E501 + "InternVLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"OpenGVLab/InternVL3-1B-hf")), # noqa: E501 + "KeyeForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"Kwai-Keye/Keye-VL-8B-Preview"), # noqa: E501 trust_remote_code=True), "KimiVLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"), # noqa: E501 extras={"thinking": os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Thinking")}, # noqa: E501 @@ -451,7 +466,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "MiniCPMO": _HfExamplesInfo(os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"), trust_remote_code=True), "MiniCPMV": _HfExamplesInfo(os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"), - extras={"2.6": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"), "4.0": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-4")}, # noqa: E501 + extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"), # noqa: E501 trust_remote_code=True, @@ -473,6 +488,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible", # noqa: E501 extras={"1.6-llama": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"), "1.6-gemma": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B")}), # noqa: E501 + "Ovis2_5": _HfExamplesInfo(os.path.join(models_path_prefix, "AIDC-AI/Ovis2.5-2B"), + trust_remote_code=True), "PaliGemmaForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"), # noqa: E501 extras={"v2": os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448")}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "microsoft/Phi-3-vision-128k-instruct"), @@ -496,26 +513,30 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B")), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B-AWQ")), # noqa: E501 + "RForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "YannQi/R-4B"), + trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"), trust_remote_code=True), "SmolVLMForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 - min_transformers_version="4.55.1", - transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55"), # noqa: E501 "Step3VLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "stepfun-ai/step3"), - trust_remote_code=True, - is_available_online=False), + trust_remote_code=True), "UltravoxModel": _HfExamplesInfo(os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"), # noqa: E501 trust_remote_code=True), - "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501 - "Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501 + "TarsierForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "omni-research/Tarsier-7b")), # noqa: E501 + "Tarsier2ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"), # noqa: E501 hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501 "VoxtralForConditionalGeneration": _HfExamplesInfo( - "mistralai/Voxtral-Mini-3B-2507", + os.path.join(models_path_prefix, "mistralai/Voxtral-Mini-3B-2507"), min_transformers_version="4.54", # disable this temporarily until we support HF format is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "naver-clova-ix/donut-base-finetuned-docvqa"), # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": os.path.join(models_path_prefix, "ByteDance/Dolphin")}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix,"microsoft/Florence-2-base"), # noqa: E501 @@ -538,6 +559,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "DeepSeekMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "luccafong/deepseek_mtp_main_random"), speculative_model=os.path.join(models_path_prefix, "luccafong/deepseek_mtp_draft_random"), # noqa: E501 trust_remote_code=True), + "EagleDeepSeekMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "eagle618/deepseek-v3-random"), + speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 + trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"), trust_remote_code=True, speculative_model=os.path.join(models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"), @@ -561,6 +585,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False, speculative_model=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16"), tokenizer=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16")), + "ErnieMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"), + trust_remote_code=True, + speculative_model=os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT")), "Glm4MoeMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "zai-org/GLM-4.5"), speculative_model=os.path.join(models_path_prefix, "zai-org/GLM-4.5"), min_transformers_version="4.54", @@ -573,7 +600,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { _TRANSFORMERS_BACKEND_MODELS = { "TransformersModel": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-Embedding-0.6B")), "TransformersForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "hmellor/Ilama-3.2-1B"), trust_remote_code=True), # noqa: E501 - "TransformersForMultimodalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B-hf")), + "TransformersForMultimodalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")), } _EXAMPLE_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index f06b34285eaead63ba19c9dbbe5676f492fe9976..b4d516233b4bfc95e0f8b4678dae454f32a67c57 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -38,11 +38,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_arch=model_arch, exist_overrides=model_info.hf_overrides) - if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): - from vllm.model_executor.models.llama4 import Llama4ForCausalLM - from vllm.model_executor.models.registry import ModelRegistry - ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) - # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: self.cache_config.num_gpu_blocks = 0 @@ -95,6 +90,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 077e026cf281a9e33db933ca901f31f5f8c8bda4..a19475a6a68fe56d9e80292b3d8218a3cd7d1542 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -27,6 +27,9 @@ models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_M @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): + # Skip if transformers version is incompatible + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_transformers_version(on_fail="skip") # Ensure all model classes can be imported successfully model_cls = ModelRegistry._try_load_model_cls(model_arch) assert model_cls is not None diff --git a/tests/models/utils.py b/tests/models/utils.py index 84aeb927c5fa96d8abc3542680ae2a0b3b2d569b..0fb1f5b3753b543ed2afe4b8832fec9b971e5d7c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,7 +3,8 @@ import warnings from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -339,36 +340,43 @@ def softmax(data): return F.softmax(data, dim=-1) -class EmbedModelInfo(NamedTuple): +@dataclass +class ModelInfo: name: str - is_matryoshka: bool = False - matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" dtype: str = "auto" + hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" enable_test: bool = True +@dataclass +class EmbedModelInfo(ModelInfo): + is_matryoshka: bool = False + matryoshka_dimensions: Optional[list[int]] = None + + +@dataclass class CLSPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "LAST" -class RerankModelInfo(NamedTuple): - name: str - architecture: str = "" - dtype: str = "auto" - default_pooling_type: str = "" - enable_test: bool = True +@dataclass +class RerankModelInfo(ModelInfo): + pass +@dataclass class CLSPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 2149f05b6af09329a8becf57ace471b67d1980e5..44c05db2278f7cbd5a2aa3c89170015f9a328892 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,32 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np import pytest import torch -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal.cache import (MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + processor_cache_from_config, + receiver_cache_from_config) +from vllm.multimodal.hasher import MultiModalHasher +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField) +from vllm.multimodal.processing import PromptInsertion +from vllm.multimodal.registry import MultiModalRegistry + +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size, ), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) -def _dummy_elem(modality: str, key: str, size: int): return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=data, field=MultiModalSharedField(1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() + _dummy_elem(modality, key, size, rng=rng) + for key, size in size_by_key.items() ]) -def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargs([ - _dummy_item(modality, size_by_key) +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): + return MultiModalKwargsItems.from_seq([ + _dummy_item(modality, size_by_key, rng=rng) for modality, size_by_key in size_by_key_modality.items() ]) @@ -37,7 +69,8 @@ def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): [ (_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100, "a2": 110}), 210), - (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501 ], ) # yapf: enable @@ -47,5 +80,139 @@ def test_cache_item_size(item, expected_size): cache[""] = item assert cache.currsize == expected_size - cache[""] = MultiModalCacheItemMetadata.wraps(item) + prompt_update = PromptInsertion("dummy", "target", "insertion") \ + .resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) + assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) assert cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + parallel_config=ParallelConfig( + data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + mm_registry = MultiModalRegistry() + cache_0_p0 = processor_cache_from_config(config_0, mm_registry) + cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_1_p0 = processor_cache_from_config(config_1, mm_registry) + cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + + cache_size_gb = max( + config_0.model_config.mm_processor_cache_gb, + config_1.model_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) + for item in all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, + selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out + else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, + selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 75a233c2567cbf5a6f5def4cd45411f097152802..2751e38760e1750f1a73476d759dca22facdf563 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -45,10 +45,11 @@ def test_hash_collision_image_transpose(): assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2) -def test_hash_collision_tensor_shape(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_hash_collision_tensor_shape(dtype): # The hash should be different though the data is the same when flattened - arr1 = torch.zeros((5, 10, 20, 3)) - arr2 = torch.zeros((10, 20, 5, 3)) + arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype) + arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype) hasher = MultiModalHasher assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index cb489c47fd8fda5b01dca461d84f8f5790727767..6ce5fcfe644bd6a83b874e5c33c1129b44b96dcd 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo, PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, - find_text_matches, find_token_matches, iter_token_matches, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import full_groupby from .utils import random_image @@ -75,12 +73,15 @@ from .utils import random_image ), ], ) +@pytest.mark.parametrize("start_idx", [0, 4, 8]) # yapf: enable -def test_iter_token_matches(token_ids, match_ids, expected): - result = list(iter_token_matches(token_ids, match_ids)) +def test_iter_token_matches(token_ids, match_ids, expected, start_idx): + result = list(iter_token_matches(token_ids, match_ids, + start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result] == expected + assert [item._asdict() for item in result + ] == [item for item in expected if item["start_idx"] >= start_idx] # Invariants match_lens = [end - start for start, end in result] @@ -241,21 +242,23 @@ def test_find_token_matches( # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_token_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_token_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -388,21 +391,23 @@ def test_find_text_matches( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_text_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_text_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -552,39 +557,35 @@ def test_find_update_text( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_text_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_text_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ - # Tokenized test cases of `test_find_replace_text` + # Tokenized test cases of `test_find_update_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -726,32 +727,28 @@ def test_find_update_tokens( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_token_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_token_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @@ -878,17 +875,11 @@ def test_find_mm_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) mm_prompt_updates = { - key: [update_type(key, [], repl).bind(mock_tokenizer)] + key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] for key, repl in repl_by_key.items() } - result = find_mm_placeholders( - mm_prompt_updates, - prompt, - # Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - ) + result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) # Only displayed on error print("result:", result) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 10241b5c2ba0813dbeb348b514c4bab39b0c2f0b..370dd0f3d405f36eb04e2d2a4f0be377f4469af3 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -21,6 +22,8 @@ from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, + get_load_balance_assignment, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables @@ -427,8 +430,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Set random seed for reproducibility current_platform.seed_everything(0) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) torch.set_default_device(device) update_environment_variables({ @@ -465,3 +468,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Check that the outputs are close (they should be identical) assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], "unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. + """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is setup correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values, + grid_thw_list) + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list) + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a750c756c11a26f43799db3066efdfeda0b95d1a --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +def register_prithvi_india(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501 + + +def register_prithvi_valencia(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d49a50b7a309f33d7b8940010e027793e1d94460 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -0,0 +1,449 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import base64 +import datetime +import os +import tempfile +import urllib.request +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Optional, Union + +import albumentations +import numpy as np +import rasterio +import regex as re +import torch +from einops import rearrange +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import (IOProcessorRequest, + IOProcessorResponse) +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput +from vllm.plugins.io_processors.interface import (IOProcessor, + IOProcessorInput, + IOProcessorOutput) + +from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput + +logger = init_logger(__name__) + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +OFFSET = 0 +PERCENTILE = 99 + +DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5] + +datamodule_config: DataModuleConfig = { + "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], + "batch_size": + 16, + "constant_scale": + 0.0001, + "data_root": + "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": + True, + "no_data_replace": + 0.0, + "no_label_replace": + -1, + "num_workers": + 8, + "test_transform": [ + albumentations.Resize(always_apply=False, + height=448, + interpolation=1, + p=1, + width=448), + albumentations.pytorch.ToTensorV2(transpose_mask=False, + always_apply=True, + p=1.0), + ], +} + + +def save_geotiff(image: torch.Tensor, meta: dict, + out_format: str) -> str | bytes: + """Save multi-band image in Geotiff file. + + Args: + image: np.ndarray with shape (bands, height, width) + output_path: path where to save the image + meta: dict with meta info. + """ + if out_format == "path": + # create temp file + file_path = os.path.join(os.getcwd(), "prediction.tiff") + with rasterio.open(file_path, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + return file_path + elif out_format == "b64_json": + with tempfile.NamedTemporaryFile() as tmpfile: + with rasterio.open(tmpfile.name, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + file_data = tmpfile.read() + return base64.b64encode(file_data) + + else: + raise ValueError("Unknown output format") + + +def _convert_np_uint8(float_image: torch.Tensor): + image = float_image.numpy() * 255.0 + image = image.astype(dtype=np.uint8) + + return image + + +def read_geotiff( + file_path: Optional[str] = None, + path_type: Optional[str] = None, + file_data: Optional[bytes] = None, +) -> tuple[torch.Tensor, dict, tuple[float, float] | None]: + """Read all bands from *file_path* and return image + meta info. + + Args: + file_path: path to image file. + + Returns: + np.ndarray with shape (bands, height, width) + meta info dict + """ + + if all([x is None for x in [file_path, path_type, file_data]]): + raise Exception("All input fields to read_geotiff are None") + write_to_file: Optional[bytes] = None + path: Optional[str] = None + if file_data is not None: + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(file_data) + # path = tmpfile.name + + write_to_file = file_data + elif file_path is not None and path_type == "url": + resp = urllib.request.urlopen(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(resp.read()) + # path = tmpfile.name + write_to_file = resp.read() + elif file_path is not None and path_type == "path": + path = file_path + elif file_path is not None and path_type == "b64_json": + image_data = base64.b64decode(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(image_data) + # path = tmpfile.name + write_to_file = image_data + else: + raise Exception("Wrong combination of parameters to read_geotiff") + + with tempfile.NamedTemporaryFile() as tmpfile: + path_to_use = None + if write_to_file: + tmpfile.write(write_to_file) + path_to_use = tmpfile.name + elif path: + path_to_use = path + + with rasterio.open(path_to_use) as src: + img = src.read() + meta = src.meta + try: + coords = src.lnglat() + except Exception: + # Cannot read coords + coords = None + + return img, meta, coords + + +def load_image( + data: Union[list[str]], + path_type: str, + mean: Optional[list[float]] = None, + std: Optional[list[float]] = None, + indices: Optional[Union[list[int], None]] = None, +): + """Build an input example by loading images in *file_paths*. + + Args: + file_paths: list of file paths . + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. + + Returns: + np.array containing created example + list of meta info for each image in *file_paths* + """ + + imgs = [] + metas = [] + temporal_coords = [] + location_coords = [] + + for file in data: + # if isinstance(file, bytes): + # img, meta, coords = read_geotiff(file_data=file) + # else: + img, meta, coords = read_geotiff(file_path=file, path_type=path_type) + # Rescaling (don't normalize on nodata) + img = np.moveaxis(img, 0, -1) # channels last for rescaling + if indices is not None: + img = img[..., indices] + if mean is not None and std is not None: + img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) + + imgs.append(img) + metas.append(meta) + if coords is not None: + location_coords.append(coords) + + try: + match = re.search(r"(\d{7,8}T\d{6})", file) + if match: + year = int(match.group(1)[:4]) + julian_day = match.group(1).split("T")[0][4:] + if len(julian_day) == 3: + julian_day = int(julian_day) + else: + julian_day = (datetime.datetime.strptime( + julian_day, "%m%d").timetuple().tm_yday) + temporal_coords.append([year, julian_day]) + except Exception: + logger.exception("Could not extract timestamp for %s", file) + + imgs = np.stack(imgs, axis=0) # num_frames, H, W, C + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W + imgs = np.expand_dims(imgs, axis=0) # add batch di + + return imgs, temporal_coords, location_coords, metas + + +class PrithviMultimodalDataProcessor(IOProcessor): + + def __init__(self, vllm_config: VllmConfig): + + super().__init__(vllm_config) + + self.datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config["data_root"], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform"], + ) + self.img_size = 512 + self.h1 = 1 + self.w1 = 1 + self.original_h = 512 + self.original_w = 512 + self.batch_size = 1 + self.meta_data = None + self.requests_cache: dict[str, dict[str, Any]] = {} + self.indices = DEFAULT_INPUT_INDICES + + def parse_request(self, request: Any) -> IOProcessorInput: + if type(request) is dict: + image_prompt = ImagePrompt(**request) + return image_prompt + if isinstance(request, IOProcessorRequest): + if not hasattr(request, "data"): + raise ValueError( + "missing 'data' field in OpenAIBaseModel Request") + + request_data = request.data + + if type(request_data) is dict: + return ImagePrompt(**request_data) + else: + raise ValueError("Unable to parse the request data") + + raise ValueError("Unable to parse request") + + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + return IOProcessorResponse( + request_id=plugin_output.request_id, + data=plugin_output, + ) + + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + + image_data = dict(prompt) + + if request_id: + self.requests_cache[request_id] = { + "out_format": image_data["out_data_format"], + } + + input_data, temporal_coords, location_coords, meta_data = load_image( + data=[image_data["data"]], + indices=self.indices, + path_type=image_data["data_format"], + ) + + self.meta_data = meta_data[0] + + if input_data.mean() > 1: + input_data = input_data / 10000 # Convert to range 0-1 + + self.original_h, self.original_w = input_data.shape[-2:] + pad_h = (self.img_size - + (self.original_h % self.img_size)) % self.img_size + pad_w = (self.img_size - + (self.original_w % self.img_size)) % self.img_size + input_data = np.pad( + input_data, + ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), + mode="reflect", + ) + + batch = torch.tensor(input_data) + windows = batch.unfold(3, self.img_size, + self.img_size).unfold(4, self.img_size, + self.img_size) + self.h1, self.w1 = windows.shape[3:5] + windows = rearrange( + windows, + "b c t h1 w1 h w -> (b h1 w1) c t h w", + h=self.img_size, + w=self.img_size, + ) + + # Split into batches if number of windows > batch_size + num_batches = (windows.shape[0] // self.batch_size + if windows.shape[0] > self.batch_size else 1) + windows = torch.tensor_split(windows, num_batches, dim=0) + + if temporal_coords: + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) + else: + temporal_coords = None + if location_coords: + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) + else: + location_coords = None + + prompts = [] + for window in windows: + # Apply standardization + window = self.datamodule.test_transform( + image=window.squeeze().numpy().transpose(1, 2, 0)) + window = self.datamodule.aug(window)["image"] + prompts.append({ + "prompt_token_ids": [1], + "multi_modal_data": { + "pixel_values": window.to(torch.float16)[0], + "location_coords": location_coords.to(torch.float16), + }, + }) + + return prompts + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + + pred_imgs_list = [] + + if request_id and (request_id in self.requests_cache): + out_format = self.requests_cache[request_id]["out_format"] + else: + out_format = "b64_json" + + for output in model_output: + y_hat = output.outputs.data.argmax(dim=1) + pred = torch.nn.functional.interpolate( + y_hat.unsqueeze(1).float(), + size=self.img_size, + mode="nearest", + ) + pred_imgs_list.append(pred) + + pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0) + + # Build images from patches + pred_imgs = rearrange( + pred_imgs, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h=self.img_size, + w=self.img_size, + b=1, + c=1, + h1=self.h1, + w1=self.w1, + ) + + # Cut padded area back to original size + pred_imgs = pred_imgs[..., :self.original_h, :self.original_w] + + # Squeeze (batch size 1) + pred_imgs = pred_imgs[0] + + if not self.meta_data: + raise ValueError("No metadata available for the current task") + self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) + out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data, + out_format) + + return ImageRequestOutput(type=out_format, + format="tiff", + data=out_data, + request_id=request_id) + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + collected_output = [item async for i, item in model_output] + return self.post_process(collected_output, request_id, **kwargs) + + +class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor): + + def __init__(self, vllm_config: VllmConfig): + + super().__init__(vllm_config) + + self.indices = [1, 2, 3, 8, 11, 12] + + +class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor): + + def __init__(self, vllm_config: VllmConfig): + + super().__init__(vllm_config) + + self.indices = [0, 1, 2, 3, 4, 5] diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py new file mode 100644 index 0000000000000000000000000000000000000000..d480aef704c6155b68ce2022b9fcd532bc3b6c07 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Literal, Optional, TypedDict, Union + +import albumentations +from pydantic import BaseModel + + +class DataModuleConfig(TypedDict): + bands: list[str] + batch_size: int + constant_scale: float + data_root: str + drop_last: bool + no_data_replace: float + no_label_replace: int + num_workers: int + test_transform: list[ + albumentations.core.transforms_interface.BasicTransform] + + +class ImagePrompt(BaseModel): + + data_format: Literal["b64_json", "bytes", "url"] + """ + This is the data type for the input image + """ + + image_format: str + """ + This is the image format (e.g., jpeg, png, etc.) + """ + + out_data_format: Literal["b64_json", "url"] + + data: Any + """ + Input image data + """ + + +MultiModalPromptType = Union[ImagePrompt] + + +class ImageRequestOutput(BaseModel): + """ + The output data of an image request to vLLM. + + Args: + type (str): The data content type [path, object] + format (str): The image format (e.g., jpeg, png, etc.) + data (Any): The resulting data. + """ + + type: Literal["path", "b64_json"] + format: str + data: str + request_id: Optional[str] = None diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a03b1fbbd4a809821c8443f601a5b774461ee0ae --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from setuptools import setup + +setup( + name="prithvi_io_processor_plugin", + version="0.1", + packages=["prithvi_io_processor"], + entry_points={ + "vllm.io_processor_plugins": [ + "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501 + "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501 + ] + }, +) diff --git a/tests/plugins_tests/conftest.py b/tests/plugins_tests/conftest.py deleted file mode 100644 index c8c1b81ca21837db7e182ef1093f378a4f632c76..0000000000000000000000000000000000000000 --- a/tests/plugins_tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') \ No newline at end of file diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py new file mode 100644 index 0000000000000000000000000000000000000000..00fe429445d7dc8550dcfce76fa833ff71d90e1e --- /dev/null +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer +from vllm.config import VllmConfig +from vllm.entrypoints.llm import LLM +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.plugins.io_processors import get_io_processor +from vllm.pooling_params import PoolingParams + +MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" + +image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + +def test_loading_missing_plugin(): + vllm_config = VllmConfig() + with pytest.raises(ValueError): + get_io_processor(vllm_config, "wrong_plugin") + + +def test_loading_engine_with_wrong_plugin(): + + with pytest.raises(ValueError): + LLM( + model=MODEL_NAME, + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + max_num_seqs=32, + io_processor_plugin="wrong_plugin", + ) + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + + with vllm_runner( + model_name, + runner="pooling", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + max_num_seqs=1, + io_processor_plugin="prithvi_to_tiff_valencia", + ) as llm_runner: + pooler_output = llm_runner.get_llm().encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + # verify the output is formatted as expected for this plugin + assert all( + hasattr(output, attr) + for attr in ["type", "format", "data", "request_id"]) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(output.data) + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--runner", + "pooling", + "--enforce-eager", + "--trust-remote-code", + "--skip-tokenizer-init", + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + "--max-num-seqs", + "32", + "--io-processor-plugin", + "prithvi_to_tiff_valencia" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_prithvi_mae_plugin_online( + server: RemoteOpenAIServer, + model_name: str, +): + + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": model_name, + } + + ret = requests.post( + server.url_for("pooling"), + json=request_payload_url, + ) + + response = ret.json() + + # verify the request response is in the correct format + assert (parsed_response := IOProcessorResponse(**response)) + + # verify the output is formatted as expected for this plugin + plugin_data = parsed_response.data + + assert all( + plugin_data.get(attr) + for attr in ["type", "format", "data", "request_id"]) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(plugin_data["data"]) diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index b7f1090ca03972ff838f724cc8ba2cedacaf4bf3..698d0a975e4916ebe2bb5676a7a7e1d02fcf7370 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -7,6 +7,15 @@ import torch from vllm.plugins import load_general_plugins +@pytest.fixture(scope="function", autouse=True) +def use_v0_only(monkeypatch): + """ + Since this module is V0 only, set VLLM_USE_V1=0 for + all tests in the module. + """ + monkeypatch.setenv('VLLM_USE_V1', '0') + + def test_platform_plugins(): # simulate workload by running an example import runpy diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 8923cf35acca8eafd265b54d4d4b0d12e174b447..8ba91620538726259b13341bc3842ae59145dec1 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -16,10 +16,10 @@ from compressed_tensors.quantization import QuantizationType from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -662,6 +662,7 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): assert output + # @pytest.mark.parametrize( # "args", # [("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", @@ -690,3 +691,61 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): # output = llm.generate_greedy("Hello my name is", max_tokens=20) # print(output) # assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() + or not current_platform.has_device_capability(90), + reason="W4A8 FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize("args", [ + ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8) +]) +def test_compressed_tensors_w4a8_fp8(vllm_runner, args): + model, scheme = args + with vllm_runner(model, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, scheme) + + assert proj.weight_packed.dtype is torch.int32 + assert proj.weight_scale.dtype is torch.float8_e4m3fn + assert proj.weight_chan_scale.dtype is torch.float32 + assert proj.scheme.group_size == 128 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("model,prompt,exp_perplexity", [ + ( + "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ( + "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), +]) +def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, + exp_perplexity): + with vllm_runner(model, enforce_eager=True) as llm: + perplexity = llm.generate_prompt_perplexity([prompt])[0] + print(perplexity) + assert perplexity <= exp_perplexity diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 578bcbfe6174fdc048303caa5761fa831d6ba97b..25d0642b330d8c26a86852f87937a4a21f18c157 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -24,22 +24,13 @@ class ModelPair: MODEL_ARG_EXPTYPES = [ # AUTOGPTQ # compat: autogptq <=0.7.1 is_marlin_format: bool - # Model Serialized in Marlin Format should always use Marlin kernel. - # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), None, "marlin"), - # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "marlin", "marlin"), - # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "gptq", "marlin"), - (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "awq", "ERROR"), # Model Serialized in Exllama Format. # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), None, "gptq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "marlin", "gptq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "gptq", "gptq"), (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "awq", "ERROR"), # compat: autogptq >=0.8.0 use checkpoint_format: str - # Model Serialized in Marlin Format should always use Marlin kernel. - # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), None, "marlin"), - # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "marlin", "marlin"), - # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "gptq", "marlin"), - (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "awq", "ERROR"), + # Model Serialized in Exllama Format. # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), None, "gptq_marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "marlin", "gptq_marlin"), diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index 8e79ec954faa418c887935250a1d24abcd9da381..2d1cf4c5a2d6181b81069002830fb6b1049a77ff 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -12,7 +12,6 @@ import os from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod) -from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( UnquantizedEmbeddingMethod) from ..utils import models_path_prefix @@ -21,9 +20,7 @@ PROMPT = "On the surface of Mars, we found" MODELS_QUANT = [ (os.path.join(models_path_prefix, "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head"), True), - (os.path.join(models_path_prefix, "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024"), False), - # (os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"), False), - # (os.path.join(models_path_prefix, "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"), False) + (os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"), False), ] @@ -43,8 +40,7 @@ def test_lm_head( lm_head_layer = model.lm_head if lm_head_quantized: assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod, - MarlinLinearMethod)) + (GPTQLinearMethod, GPTQMarlinLinearMethod)) else: assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) @@ -52,5 +48,5 @@ def test_lm_head( vllm_model.apply_model(check_model) print( - vllm_model.generate_greedy(prompts=["Hello my name is"], - max_tokens=10)[0][1]) \ No newline at end of file + vllm_model.generate_greedy(["Hello my name is"], + max_tokens=10)[0][1]) diff --git a/tests/quantization/untest_fp8.py b/tests/quantization/untest_fp8.py index 9d7e0e85ab92f2a0b732d65e4f9001a71b5feead..1debc12fe901567352aded2d09b5fd10400987ed 100644 --- a/tests/quantization/untest_fp8.py +++ b/tests/quantization/untest_fp8.py @@ -40,8 +40,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) @@ -92,8 +91,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 20ba12c0ecd9a4bc1efaa854580ad1d845286c43..8e0364189298df8541ffc24a79442c731ba32713 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -69,6 +69,59 @@ def test_beam_search_single_input( f"vLLM: {vllm_output_ids}") +@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_with_concurrency_limit( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + # example_prompts[1]&[3]&[7] fails due to unknown reason even without + # concurency limit. skip them for now. + example_prompts = (example_prompts[:8]) + concurrency_limit = 2 + assert len(example_prompts) > concurrency_limit + with vllm_runner(model, dtype=dtype) as vllm_model: + outputs_with_limit = vllm_model.generate_beam_search( + example_prompts, + beam_width, + max_tokens, + concurrency_limit=concurrency_limit) + outputs_without_limit = [] + + for i in range(0, len(example_prompts), concurrency_limit): + outputs_without_limit.extend( + vllm_model.generate_beam_search( + example_prompts[i:i + concurrency_limit], beam_width, + max_tokens)) + + correct = True + for i in range(len(example_prompts)): + output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] + output_ids_without_limit, output_texts_without_limit = ( + outputs_without_limit[i]) + for j, (text_with_limit, text_without_limit) in enumerate( + zip(output_texts_with_limit, output_texts_without_limit)): + print(f">>>{j}-th with limit output:") + print(text_with_limit) + print(f">>>{j}-th without limit output:") + print(text_without_limit) + assert len(output_ids_with_limit) == len(output_ids_without_limit) + for j in range(len(output_ids_with_limit)): + if output_ids_with_limit[j] != output_ids_without_limit[j]: + print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}") + correct = False + assert correct + + @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c734c8514a6da53e65bd2614bfcb44e4bb20471a..1b019be9e56dc8541a8e51a471c6e8ab6121a62f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -98,3 +99,38 @@ def test_sequence_group_stage(): assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(1) assert seq_group.is_prefill() is False + + +def test_sequence_intermediate_tensors_equal(): + + class AnotherIntermediateTensors(IntermediateTensors): + pass + + intermediate_tensors = IntermediateTensors({}) + another_intermediate_tensors = AnotherIntermediateTensors({}) + assert intermediate_tensors != another_intermediate_tensors + + empty_intermediate_tensors_1 = IntermediateTensors({}) + empty_intermediate_tensors_2 = IntermediateTensors({}) + assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 + + different_key_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + difference_key_intermediate_tensors_2 = IntermediateTensors( + {"2": torch.zeros([2, 4], dtype=torch.int32)}) + assert (different_key_intermediate_tensors_1 + != difference_key_intermediate_tensors_2) + + same_key_different_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_different_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 5], dtype=torch.int32)}) + assert (same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2) + + same_key_same_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_same_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + assert (same_key_same_value_intermediate_tensors_1 == + same_key_same_value_intermediate_tensors_2) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index d4fc7a3d3a12edcdb8ec04926d4326c4a3112c13..9c50707ae55bf76c58b2a5b1a4855b0835bcd64f 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -67,8 +67,6 @@ def _run_incremental_decode(tokenizer, request = EngineCoreRequest("", prompt_token_ids, None, - None, - None, params, None, None, diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 40c3158e9e683f118c4c1cc38d370a799a3959bd..ccb2acf512caf584ee2b96bbef7dab349691ad63 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -16,7 +16,7 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" +MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" @pytest.fixture(scope="module") @@ -397,7 +397,9 @@ hello world "no_tools", "single_tool", "single_tool_with_content", + "single_tool_multiline_param", "parallel_tools", + "tool_with_typed_params", # Added this test case ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -422,7 +424,7 @@ fahrenheit "state": "TX", "unit": "fahrenheit" }))) - ], ""), + ], None), ('''Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> @@ -445,6 +447,30 @@ fahrenheit }))) ], "Sure! Let me check the weather for you."), ('''<tool_call> +<function=calculate_area> +<parameter=shape> +rectangle +</parameter> +<parameter=dimensions> +{"width": 10, + "height": 20} +</parameter> +<parameter=precision> +2 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + ('''<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -484,13 +510,36 @@ celsius "state": "FL", "unit": "celsius" }))) - ], ""), + ], None), + # Added tool_with_typed_params test case + ('''Let me calculate that area for you.<tool_call> +<function=calculate_area> +<parameter=shape> +circle +</parameter> +<parameter=dimensions> +{"radius": 15.5} +</parameter> +<parameter=precision> +3 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), ], ) def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, sample_tools, model_output, expected_tool_calls, expected_content): - """Test incremental streaming behavior""" + """Test incremental streaming behavior including typed parameters""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) @@ -539,7 +588,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, "arguments"] += tool_call.function.arguments # Verify final content - assert other_content == expected_content + assert other_content == (expected_content or "") # Handle None case # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) @@ -559,6 +608,125 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, assert actual_args == expected_args +def test_extract_tool_calls_missing_closing_parameter_tag( + qwen3_tool_parser, sample_tools): + """Test handling of missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + # The parser should handle the malformed XML gracefully + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + + # Verify the function name is correct + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_weather" + + # Verify the arguments are parsed despite the missing closing tag + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert "city" in args + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + # Check that content before the tool call is preserved + assert "Let me check the weather for you:" in extracted_tool_calls.content + + +def test_extract_tool_calls_streaming_missing_closing_tag( + qwen3_tool_parser, qwen3_tokenizer, sample_tools): + """Test streaming with missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "Let me check the weather for you:" in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing closing tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, qwen3_tokenizer, sample_tools): diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c276a598aa68c0304e8415da5878eb184417e5f9 --- /dev/null +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" + + +@pytest.fixture(scope="module") +def seed_oss_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def seed_oss_tool_parser(seed_oss_tokenizer): + return SeedOssToolParser(seed_oss_tokenizer) + + +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get current temperature for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "City and country e.g. Bogotá, Colombia" + }, + "unit": { + "type": "string", + "description": "this is the unit of temperature" + } + }, + "required": ["location"], + "additionalProperties": False + }, + "returns": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "description": "temperature in celsius" + } + }, + "required": ["temperature"], + "additionalProperties": False + }, + "strict": True + }), + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Seed-OSS tool call will not generate id + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + assert actual_tool_call.function.name == expected_tool_call.function.name + assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + + +def test_extract_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], None), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=request) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + + result = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text="his is a test response", + current_text=model_output, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." + + +def stream_delta_message_generator( + seed_oss_tool_parser: SeedOssToolParser, + seed_oss_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = seed_oss_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + ), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, + sample_tools, model_output, expected_tool_calls, + expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + arguments_str = state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 17bc58b9805fc243977e244b8e03945ea378d7cc..6d4b15860e8b054962ba46c097c8c28c45d42e34 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -3,12 +3,17 @@ import os import json +from collections.abc import Generator +from typing import Optional import pytest -from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, FunctionCall, + ToolCall) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from ..utils import models_path_prefix # Use a common model that is likely to be available @@ -38,6 +43,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], assert actual_tool_call.function == expected_tool_call.function +def stream_delta_message_generator( + xlam_tool_parser: xLAMToolParser, + xlam_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None, +) -> Generator[DeltaMessage, None, None]: + all_token_ids = xlam_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = (detokenize_incrementally( + tokenizer=xlam_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + )) + + current_text = previous_text + delta_text + + delta_message = xlam_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + def test_extract_tool_calls_no_tools(xlam_tool_parser): model_output = "This is a test" extracted_tool_calls = xlam_tool_parser.extract_tool_calls( @@ -53,6 +108,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): "single_tool_with_think_tag", "single_tool_with_json_code_block", "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -120,6 +176,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): ], "I'll check the weather for you.", ), + ( + """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "I'll help you check the weather.", + ), ], ) def test_extract_tool_calls(xlam_tool_parser, model_output, @@ -247,3 +317,147 @@ def test_streaming_with_list_structure(xlam_tool_parser): assert hasattr(result, "tool_calls") assert len(result.tool_calls) == 1 assert result.tool_calls[0].function.name == "get_current_weather" + + +@pytest.mark.parametrize( + ids=[ + "parallel_tool_calls", + "single_tool_with_think_tag", + "single_tool_with_json_code_block", + "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )), + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + }), + )), + ], + "", + ), + ( + """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "<think>I'll help you with that.</think>", + ), + ( + """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "", + ), + ( + """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "", + ), + ( + """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "I can help with that.", + ), + ], +) +def test_extract_tool_calls_streaming_incremental( + xlam_tool_parser, + xlam_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501 + request = ChatCompletionRequest(model=MODEL, messages=[], tools=[]) + + chunks = [] + for delta_message in stream_delta_message_generator( + xlam_tool_parser, xlam_tokenizer, model_output, request): + chunks.append(delta_message) + + # Should have multiple chunks + assert len(chunks) >= 3 + + # Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501 + header_found = False + expected_first_tool = expected_tool_calls[0] + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].id: + header_found = True + assert (chunk.tool_calls[0].function.name == + expected_first_tool.function.name) + assert chunk.tool_calls[0].type == "function" + # Arguments may be empty initially or None + if chunk.tool_calls[0].function.arguments is not None: + # If present, should be empty string initially + assert chunk.tool_calls[0].function.arguments == "" + break + assert header_found + + # Should have chunks with incremental arguments + arg_chunks = [] + for chunk in chunks: + if (chunk.tool_calls and chunk.tool_calls[0].function.arguments + and chunk.tool_calls[0].function.arguments != "" + and chunk.tool_calls[0].index == + 0 # Only collect arguments from the first tool call + ): + arg_chunks.append(chunk.tool_calls[0].function.arguments) + + # Arguments should be streamed incrementally + assert len(arg_chunks) > 1 + + # Concatenated arguments should form valid JSON for the first tool call + full_args = "".join(arg_chunks) + parsed_args = json.loads(full_args) + expected_args = json.loads(expected_first_tool.function.arguments) + assert parsed_args == expected_args diff --git a/tests/utils.py b/tests/utils.py index 7e3b8a98054364cdc8b912342251819de8efab82..35c05186cc4172cad56b75be59541a649a308e95 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import asyncio import copy import functools import importlib +import json import os import signal import subprocess @@ -107,7 +108,8 @@ class RemoteOpenAIServer: env_dict: Optional[dict[str, str]] = None, seed: Optional[int] = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None) -> None: + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: raise ValueError("You have manually specified the port " @@ -126,6 +128,12 @@ class RemoteOpenAIServer: vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") @@ -694,9 +702,12 @@ def multi_process_parallel( os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1" ray.init( runtime_env={ - "working_dir": VLLM_PATH, - "excludes": - ["build", ".git", "cmake-build-*", "shellcheck", "dist"] + "working_dir": + VLLM_PATH, + "excludes": [ + "build", ".git", "cmake-build-*", "shellcheck", "dist", + "ep_kernels_workspace" + ] }) distributed_init_port = get_open_port() diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 69bb6e4c4f23de3efce66dc53ddc71732ab2847c..e69dc0d44194039070f111db1b6059d9f2668c37 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -5,14 +5,18 @@ import asyncio import hashlib import json +import os import pickle import socket +import tempfile from collections.abc import AsyncIterator +from pathlib import Path from unittest.mock import patch import os import pytest import torch +import yaml import zmq from transformers import AutoTokenizer from vllm_test_utils.monitor import monitor @@ -377,9 +381,9 @@ def test_duplicate_dict_args(caplog_vllm, parser): def test_supports_kw(callable,kw_name,requires_kw_only, allow_var_kwargs,is_supported): assert supports_kw( - callable=callable, - kw_name=kw_name, - requires_kw_only=requires_kw_only, + callable=callable, + kw_name=kw_name, + requires_kw_only=requires_kw_only, allow_var_kwargs=allow_var_kwargs ) == is_supported @@ -946,6 +950,36 @@ def test_join_host_port(): assert join_host_port("::1", 5555) == "[::1]:5555" +def test_json_count_leaves(): + """Test json_count_leaves function from jsontree utility.""" + from vllm.utils.jsontree import json_count_leaves + + # Single leaf values + assert json_count_leaves(42) == 1 + assert json_count_leaves("hello") == 1 + assert json_count_leaves(None) == 1 + + # Empty containers + assert json_count_leaves([]) == 0 + assert json_count_leaves({}) == 0 + assert json_count_leaves(()) == 0 + + # Flat structures + assert json_count_leaves([1, 2, 3]) == 3 + assert json_count_leaves({"a": 1, "b": 2}) == 2 + assert json_count_leaves((1, 2, 3)) == 3 + + # Nested structures + nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} + assert json_count_leaves(nested_dict) == 3 + + nested_list = [1, [2, 3], 4] + assert json_count_leaves(nested_list) == 4 + + mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} + assert json_count_leaves(mixed_nested) == 4 + + def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") @@ -993,3 +1027,40 @@ def test_current_stream_multithread(): child_thread.join(timeout=5) if child_thread.is_alive(): pytest.fail("Child thread failed to exit properly") + + +def test_load_config_file(tmp_path): + # Define the configuration data + config_data = { + "enable-logging": True, + "list-arg": ["item1", "item2"], + "port": 12323, + "tensor-parallel-size": 4 + } + + # Write the configuration data to a temporary YAML file + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as config_file: + yaml.dump(config_data, config_file) + + # Initialize the parser + parser = FlexibleArgumentParser() + + # Call the function with the temporary file path + processed_args = parser.load_config_file(str(config_file_path)) + + # Expected output + expected_args = [ + "--enable-logging", + "--list-arg", + "item1", + "item2", + "--port", + "12323", + "--tensor-parallel-size", + "4", + ] + + # Assert that the processed arguments match the expected output + assert processed_args == expected_args + os.remove(str(config_file_path)) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ac08b9052cd80b742cf255985a81920bb82dda77..e4c07aae0ebedf1a1bf35392d3d7b5a7ea6effd9 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW" ] # Remove flashinfer from the list if it's not available @@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - + Args: k_contexts: List of key context tensors for each sequence v_contexts: List of value context tensors for each sequence @@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + Returns: Tuple of (kv_cache, updated_block_table) """ @@ -150,15 +151,15 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] # Construct the right block table @@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, kv_cache: torch.Tensor) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = _Backend.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if backend == _Backend.FLASHINFER_VLLM_V1: + if actual_backend == _Backend.FLASHINFER_VLLM_V1: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == _Backend.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -281,7 +292,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, @pytest.mark.parametrize("batch_spec_name", [ "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" ]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_backend_correctness(batch_spec_name: str, model: str): @@ -302,7 +314,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): """ batch_spec = BATCH_SPECS[batch_spec_name] vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens)) + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=8192) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -451,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-3 - if backend_name == _Backend.FLEX_ATTENTION: - atol = 5e-1 # TODO: figure out why flex_attention has such large - # numerical differences for medium_decode, medium_prefill, - # mixed_medium - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_rel_diff = torch.max( torch.abs(backend_output - sdpa_output) / @@ -465,12 +473,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol=rtol, atol=atol) - if not all_close: - print(f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") - print(f"[{backend_name}] output: {backend_output}") - print(f"[{backend_name}] SDPA baseline: {sdpa_output}") - assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..59e56281494680a893e29bbae160b5d51c7b5f15 --- /dev/null +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for mamba attention backend selectors.""" + +from types import SimpleNamespace + +import pytest + +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.models.minimax_text_01 import ( + MiniMaxText01LinearAttention) +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + + +@pytest.mark.parametrize( + "layer_class, init_kwargs, expected_backend, expected_mamba_type", [ + ( + MambaMixer, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + time_step_rank=8, + use_conv_bias=True, + use_bias=False, + use_rms_norm=True, + ), + Mamba1AttentionBackend, + "mamba1", + ), + ( + MambaMixer2, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + use_conv_bias=True, + use_bias=False, + n_groups=1, + num_heads=8, + head_dim=32, + ), + Mamba2AttentionBackend, + "mamba2", + ), + ( + MiniMaxText01LinearAttention, + dict( + hidden_size=128, + hidden_inner_size=256, + num_heads=8, + head_dim=32, + max_position=2048, + block_size=64, + num_hidden_layer=12, + layer_idx=0, + linear_layer_idx=0, + ), + LinearAttentionBackend, + "linear_attention", + ), + ( + ShortConv, + dict( + config=SimpleNamespace(conv_L_cache=32, conv_bias=True), + dim=128, + layer_idx=0, + ), + ShortConvAttentionBackend, + "short_conv", + ), + ]) +def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, + expected_backend, expected_mamba_type): + """Test that Mamba-like layers return the correct attention backend.""" + layer = layer_class(**init_kwargs) + + backend_class = layer.get_attn_backend() + assert backend_class is expected_backend + assert layer.mamba_type == expected_mamba_type + + +@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), +]) +def test_mamba_layers_have_unified_interface(layer_class, expected_backend, + expected_mamba_type): + """Test that all Mamba layers have the unified get_attn_backend + interface.""" + assert hasattr(layer_class, 'get_attn_backend'), ( + f"{layer_class.__name__} should have get_attn_backend method") + assert hasattr(layer_class, 'mamba_type'), ( + f"{layer_class.__name__} should have mamba_type property") diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py deleted file mode 100644 index 4245b50c71310d0db924be0f4e1ea83b5006a33f..0000000000000000000000000000000000000000 --- a/tests/v1/attention/test_mamba_selectors.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mamba attention backend selectors.""" - -import pytest - -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend - - -@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"], - argvalues=[("mamba2", Mamba2AttentionBackend)]) -def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend): - backend_class = get_mamba_attn_backend(mamba_type) - - assert backend_class is expected_backend - - -def test_get_mamba_attn_backend_unsupported(): - unsupported_types = ["mamba", ""] - - for mamba_type in unsupported_types: - err_message = f"Mamba Attention type {mamba_type} is not supported yet." - with pytest.raises(NotImplementedError, match=err_message): - get_mamba_attn_backend(mamba_type) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py new file mode 100644 index 0000000000000000000000000000000000000000..24070358799ef75eb3eb9e13abd8389638efab40 --- /dev/null +++ b/tests/v1/attention/test_mla_backends.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 MLA backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.TRITON_MLA_VLLM_V1 +] + +# Remove CUTLASS_MLA from the list if not using sm100 +if not torch.cuda.is_available() or torch.cuda.get_device_properties( + 0).major < 10: + BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + +torch.manual_seed(42) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.head_size, # latent dimension + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate an MLA KV cache with context data. + + Args: + kv_c_contexts: List of latent KV context tensors for each sequence + k_pe_contexts: List of key positional embedding context tensors + for each sequence + block_size: Size of each block + num_kv_heads: Number of KV heads (should be 1 for MLA) + head_size: Size of each head (latent dimension) + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Common attention metadata + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + MLA KV cache tensor + """ + batch_size = len(kv_c_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + start = start_block_idx * block_size + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] = kv_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + layer_names: list[str], vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, kv_c: torch.Tensor, + k_pe: torch.Tensor, kv_cache: torch.Tensor, + kv_lora_rank: int, qk_nope_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + mock_kv_b_proj) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) + + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty(num_tokens, + num_heads * v_head_dim, + dtype=query.dtype, + device=query.device) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + output = impl.forward(mock_layer, + query, + kv_c, + k_pe, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_backend_correctness(dist_init, batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=2048) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + total_head_size = kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank + qk_rope_head_dim == head_size, \ + f"MLA dimensions don't match: {total_head_size} != {head_size}" + scale = 1.0 / (total_head_size**0.5) + + # 2. Generate data and compute SDPA reference output for MLA + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + all_sdpa_outputs = [] + kv_c_contexts, k_pe_contexts = [], [] + + # Create shared MLA weight matrices for consistency across all sequences + W_UK = torch.randn(kv_lora_rank, + num_q_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_q_heads, + v_head_dim, + dtype=dtype, + device=device) + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate MLA tensors + # Q has both nope and rope components: + # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] + q_c = torch.randn(q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + + # KV_C (latent K/V): [s_len, kv_lora_rank] + kv_c_full = torch.randn(s_len, + kv_lora_rank, + dtype=dtype, + device=device) + + # K_PE (rope component): [s_len, 1, qk_rope_head_dim] + k_pe_full = torch.randn(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + # Determine if this is decode (single token) + # or prefill (multiple tokens) + is_decode = q_len == 1 + + # Split q into nope and rope components + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + + if is_decode: + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + else: + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, + kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, + s_len, + dtype=torch.bool, + device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + + all_sdpa_outputs.append(sdpa_out_i) + + # Inputs for vLLM MLA backends are just the new tokens + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens + all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens + + # Contextual K/V data used to populate the paged cache (MLA format) + kv_c_contexts.append(kv_c_full[:context_len]) + k_pe_contexts.append(k_pe_full[:context_len]) + + # Concatenate all sequences (no reordering needed) + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + # Create mock kv_b_proj using the same weights as reference implementation + from vllm.model_executor.layers.linear import ColumnParallelLinear + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_q_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + + # Set the mock weights to match our reference implementation + # Reshape W_UK and W_UV to match the expected kv_b_proj format + # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + + # Create metadata using original batch spec + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. Run vLLM backends and compare + for backend_name in BACKENDS_TO_TEST: + backend_output = run_attention_backend( + backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, + common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, + kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + mock_kv_b_proj) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != " + f"SDPA shape {sdpa_output.shape}") + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != " + f"SDPA dtype {sdpa_output.dtype}") + + assert torch.isfinite(backend_output).all(), ( + f"[{backend_name}] produced non-finite values") + + # Check numerical similarity + rtol = 1e-2 + atol = 5e-1 + + max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_rel_diff = torch.max( + torch.abs(backend_output - sdpa_output) / + torch.abs(sdpa_output)).item() + all_close = torch.allclose(backend_output, + sdpa_output, + rtol=rtol, + atol=atol) + + assert all_close, ( + f"[{backend_name}] output differs from SDPA baseline. " + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index a4e38eb32f6a1a0ff66a62cb3687591a7045bc7e..6a08cdc56f7367323995fe483a51fc9e4fc1e91c 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -58,6 +58,7 @@ def create_common_attn_metadata( dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) # Create computed tokens (context length for each sequence) context_lens = [ @@ -101,6 +102,7 @@ def create_common_attn_metadata( num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, causal=True, @@ -133,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", _Backend.XFORMERS_VLLM_V1: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", + _Backend.CUTLASS_MLA: + "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", + _Backend.FLASHMLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.TRITON_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } if backend_name not in backend_map: @@ -165,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -187,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) # Set cache blocks for testing # (these may be set during initialization normally) - cache_config.num_gpu_blocks = 1000 + cache_config.num_gpu_blocks = num_gpu_blocks cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( @@ -196,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, ) device_config = DeviceConfig() diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3a9492269f9c9e4fcb4ceeac8e4d2072e38bffb3..c153e38fe3df36966c0c9b82cb57a848d3f27176 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -22,7 +22,6 @@ def _make_model_runner_output( for i, req_id in enumerate(req_ids) }, sampled_token_ids=[[i] for i in range(len(req_ids))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5b751f45a4b13fff46f75c937c4169938841a5 --- /dev/null +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + + def __init__(self, request_id, mm_hashes, token_counts): + self.request_id = request_id + self.mm_hashes = mm_hashes + self._token_counts = token_counts + + def get_num_encoder_tokens(self, input_id: int) -> int: + return self._token_counts[input_id] + + +# ------------------ Unit Tests ------------------ # +def test_basic_allocate_and_reuse(): + cache = EncoderCacheManager(cache_size=10) + req = MockRequest("r1", ["imgA"], [4]) + + assert not cache.check_and_update_cache(req, 0) + assert cache.can_allocate(req, 0, int(1e9), 0) + + cache.allocate(req, 0) + + assert cache.check_and_update_cache(req, 0) + assert "r1" in cache.cached["imgA"] + assert cache.num_free_slots == 6 + + # Free twice to bring refcount to 0. + cache.free_encoder_input(req, 0) + cache.free_encoder_input(req, 0) + + assert not cache.cached["imgA"] + assert "imgA" in cache.freeable + assert cache.num_freeable_slots == 10 + assert cache.num_free_slots == 6 + + +def test_freeing_decreases_refcount_and_moves_to_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req2", ["img3"], [5]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert len(manager.cached["img3"]) == 1 + + manager.free_encoder_input(req, 0) + + assert not manager.cached["img3"] + assert "img3" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_free_request_frees_all_inputs(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req3", ["a", "b"], [2, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 1, int(1e9), 0) + manager.allocate(req, 1) + + assert len(manager.cached["a"]) == 1 + assert len(manager.cached["b"]) == 1 + + manager.free(req) + + assert not manager.cached["a"] + assert not manager.cached["b"] + assert "a" in manager.freeable + assert "b" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_eviction_when_cache_is_full(): + manager = EncoderCacheManager(cache_size=10) + + req1 = MockRequest("req1", ["x"], [6]) + req2 = MockRequest("req2", ["y"], [5]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + # 'x' should have been evicted. + assert "x" not in manager.cached + assert "x" in manager.get_freed_mm_hashes() + + +def test_get_cached_input_ids(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 2, int(1e9), 0) + manager.allocate(req, 2) + + cached_ids = manager.get_cached_input_ids(req) + assert cached_ids == {0, 2} + + +def test_has_cache_restores_from_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqY", ["imgZ"], [4]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + manager.free_encoder_input(req, 0) + + # Should restore from freeable. + assert manager.check_and_update_cache(req, 0) + assert len(manager.cached["imgZ"]) == 1 + assert "imgZ" not in manager.freeable + assert manager.num_freeable_slots == 6 + + +def test_get_freed_mm_hashes_clears_freed_list(): + manager = EncoderCacheManager(cache_size=10) + req1 = MockRequest("reqA", ["a"], [5]) + req2 = MockRequest("reqB", ["b"], [6]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + # Should trigger eviction of 'a'. + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + freed = manager.get_freed_mm_hashes() + assert "a" in freed + assert manager.get_freed_mm_hashes() == [] + + +def test_schedule_request_multi_images_respect_space_limit(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 100 + + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) + + +def test_schedule_request_multi_images_respect_compute_limit(): + manager = EncoderCacheManager(cache_size=100) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 10 + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index b9d07a40e961984b74ee3bd66167a392aae90980..d4c98f50d1435fba84bb7b4e523e9d3203751df8 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -8,7 +8,8 @@ import pytest import torch from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -39,17 +40,20 @@ def make_request( mm_hashes: Optional[list[str]] = None, cache_salt: Optional[str] = None, ): - if mm_positions is None: - mm_kwargs = None - else: - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_positions) + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) return Request(request_id=request_id, prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, + mm_features=mm_features if mm_features else None, sampling_params=SamplingParams(max_tokens=17), pooling_params=None, eos_token_id=100, @@ -599,8 +603,14 @@ def test_unify_kv_cache_configs(): ] unify_kv_cache_configs(need_sort_kv_cache_config) - assert need_sort_kv_cache_config[0].num_blocks == 10 - assert need_sort_kv_cache_config[1].num_blocks == 10 + sorted_kv_cache_groups = [ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ] + assert ( + need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups) + assert ( + need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups) diff_kv_cache_config = [ KVCacheConfig( diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 89824768ed909b309aa31b91a68fa5d86ab240ef..e7a8f63702b300127f7960ac822cde5dd3775ea6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -9,7 +9,8 @@ import pytest import torch from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.block_pool import BlockPool @@ -32,17 +33,20 @@ def make_request( prompt_logprobs: Optional[int] = None, cache_salt: Optional[str] = None, ): - if mm_positions is None: - mm_kwargs = None - else: - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_positions) + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) return Request(request_id=request_id, prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, + mm_features=mm_features if mm_features else None, sampling_params=SamplingParams( max_tokens=17, prompt_logprobs=prompt_logprobs), pooling_params=None, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index a9048f2d4a384bd16ca901c4f6cc6cc665e1dc59..e78863512b315a95a5235dccd0c786dd2e96e09e 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -9,13 +9,14 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -161,7 +162,6 @@ def test_schedule_partial_requests(): # Only the first request has a sampled token id because # the rest requests are still being prefilled. sampled_token_ids=[[0], [], []], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -212,7 +212,6 @@ def test_no_mm_input_chunking(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -276,7 +275,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -301,7 +299,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -345,7 +342,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -358,7 +355,6 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -399,7 +395,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -412,7 +408,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -452,7 +447,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -465,7 +460,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -500,7 +494,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -508,7 +502,6 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -557,7 +550,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -575,7 +567,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -611,7 +602,6 @@ def test_preempt_during_execution(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -629,7 +619,6 @@ def test_preempt_during_execution(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -685,13 +674,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], - spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = scheduler.running[i] @@ -725,7 +715,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=output_tokens, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -854,7 +843,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -901,7 +889,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -969,7 +956,6 @@ def test_kv_connector_unable_to_allocate(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1008,133 +994,133 @@ def test_kv_connector_unable_to_allocate(): assert len(scheduler.waiting) == 0 -# def test_kv_connector_handles_preemption(): -# """ -# Test whether scheduler with KVConnector is able to handle -# unable to allocate (run out of blocks in allocate_slots(). -# """ - -# # Setup Scheduler With Mock External Cache Hit. -# BLOCK_SIZE = 2 -# # NOTE: there is 1 null block, so this is 6 blocks. -# NUM_BLOCKS = 7 -# scheduler = create_scheduler( -# enable_prefix_caching=True, -# use_kv_connector=True, -# block_size=BLOCK_SIZE, -# num_blocks=NUM_BLOCKS, -# ) - -# NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE -# scheduler.connector.get_num_new_matched_tokens = Mock(name="method") -# scheduler.connector.get_num_new_matched_tokens.return_value = ( -# NUM_MATCHED_NEW_TOKENS, False) - -# # Create two requests. -# # Both can be scheduled at first, but the second request -# # will be preempted and re-scheduled. -# NUM_REQUESTS = 2 -# NUM_TOKENS = BLOCK_SIZE * 2 + 1 -# MAX_TOKENS = BLOCK_SIZE * 2 -# requests = create_requests(num_requests=NUM_REQUESTS, -# num_tokens=NUM_TOKENS, -# max_tokens=MAX_TOKENS, -# block_size=BLOCK_SIZE) -# req_ids = [] -# req_to_index = {} -# for i, request in enumerate(requests): -# scheduler.add_request(request) -# req_ids.append(request.request_id) -# req_to_index[request.request_id] = i - -# MODEL_RUNNER_OUTPUT = ModelRunnerOutput( -# req_ids=req_ids, -# req_id_to_index=req_to_index, -# sampled_token_ids=[[1000]] * len(req_ids), -# spec_token_ids=None, -# logprobs=None, -# prompt_logprobs_dict={}, -# pooler_output=[], -# ) - -# # All can be scheduled - 1st token. -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # 2 remote kv cache hits. -# num_requests=2, -# expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) -# assert len(scheduler.running) == 2 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) - -# # All can be scheduled - 2nd token. -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # no connector_metadata -# num_requests=0, -# expected_num_scheduled_tokens=1) -# assert len(scheduler.running) == 2 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) - -# # This will generate a new block and cause a preemption - 3rd token. -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # no connector_metadata -# num_requests=0, -# expected_num_scheduled_tokens=1) -# assert len(scheduler.running) == 1 -# assert len(scheduler.waiting) == 1 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) -# assert len(scheduler.running) == 1 -# assert len(scheduler.waiting) == 1 - -# # Only 1 can be scheduled - 4th (and last token). -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # no connector_metadata -# num_requests=0, -# expected_num_scheduled_tokens=1) -# assert len(scheduler.waiting) == 1 -# assert len(scheduler.running) == 1 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) -# assert len(scheduler.running) == 0 -# # All memory should be freed since nothing is running. -# assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ -# == NUM_BLOCKS - 1 - -# # Restarts the preempted request - generate 3rd token. -# # This will have a local and remote cache hit. -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # 1 remote kv_cache hit! -# num_requests=1, -# # Only 1 block was preempted and there is a single -# # remote hit. So only single new token is scheduled. -# expected_num_scheduled_tokens=1, -# ) -# assert len(scheduler.running) == 1 -# assert len(scheduler.waiting) == 0 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) -# assert len(scheduler.running) == 1 -# assert len(scheduler.waiting) == 0 - -# # Only 1 can be scheduled - 4th (and last token). -# output = scheduler.schedule() -# _assert_right_scheduler_output( -# output, -# # no connector_metadata -# num_requests=0, -# expected_num_scheduled_tokens=1) -# assert len(scheduler.running) == 1 -# _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) -# assert len(scheduler.running) == 0 -# # All memory should be freed since nothing is running. -# assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ -# == NUM_BLOCKS - 1 +def test_kv_connector_handles_preemption(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 2 if not current_platform.is_rocm() else 64 + # NOTE: there is 1 null block, so this is 6 blocks. + NUM_BLOCKS = 7 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS, False) + + # Create two requests. + # Both can be scheduled at first, but the second request + # will be preempted and re-scheduled. + NUM_REQUESTS = 2 + NUM_TOKENS = BLOCK_SIZE * 2 + 1 + MAX_TOKENS = BLOCK_SIZE * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + # All can be scheduled - 1st token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 2 remote kv cache hits. + num_requests=2, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # All can be scheduled - 2nd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # This will generate a new block and cause a preemption - 3rd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + + # Restarts the preempted request - generate 3rd token. + # This will have a local and remote cache hit. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 1 remote kv_cache hit! + num_requests=1, + # Only 1 block was preempted and there is a single + # remote hit. So only single new token is scheduled. + expected_num_scheduled_tokens=1, + ) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + def make_output(scheduler: Scheduler): @@ -1145,7 +1131,6 @@ def make_output(scheduler: Scheduler): for i, req in enumerate(scheduler.running) }, sampled_token_ids=[[1000]] * len(scheduler.running), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1313,7 +1298,8 @@ def create_requests_with_priority( mm_positions: Optional[list[list[PlaceholderRange]]] = None, max_tokens: int = 16, stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): + prompt_logprobs: Optional[int] = None, + starting_idx: int = 0): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests if arrival_times is not None: @@ -1327,21 +1313,24 @@ def create_requests_with_priority( prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_position) - else: - mm_position = None - mm_kwargs = None + for j, position in enumerate(mm_position): + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) + request = Request( - request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, + request_id=f"{i + starting_idx}", + prompt_token_ids=[i + starting_idx] * num_tokens, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=mm_kwargs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, arrival_time=arrival_times[i], priority=priorities[i], @@ -1471,7 +1460,6 @@ def test_priority_scheduling_preemption(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1544,7 +1532,6 @@ def test_priority_scheduling_no_preemption_when_space_available(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1786,7 +1773,6 @@ def test_priority_scheduling_heap_property(): req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1823,9 +1809,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): request = Request( request_id="0", prompt_token_ids=[0, 1], - multi_modal_kwargs=None, - multi_modal_hashes=None, - multi_modal_placeholders=None, + mm_features=None, sampling_params=sampling_params, pooling_params=None, eos_token_id=EOS_TOKEN_ID, @@ -1836,3 +1820,87 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 + + +def test_priority_scheduling_preemption_when_out_of_kv(): + """Test that priority scheduling preempts lower priority requests + when out of KV cache space.""" + # Create scheduler with very limited memory to force preemption + scheduler = create_scheduler_with_priority( + max_num_seqs=2, # Allow multiple requests + max_num_batched_tokens=200, + num_blocks=5, # Can hold 64 tokens (first block is null) + block_size=16, # Standard block size + ) + + # Create a request and schedule it + request_low = create_requests_with_priority( + num_requests=1, + priorities=[1], + arrival_times=[0.0], + num_tokens=30, + starting_idx=0, + )[0] + scheduler.add_request(request_low) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Simulate model execution + model_output = ModelRunnerOutput( + req_ids=[request_low.request_id], + req_id_to_index={request_low.request_id: 0}, + sampled_token_ids=[[100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Create a high priority request and schedule it + request_high = create_requests_with_priority( + num_requests=1, + priorities=[0], + arrival_times=[1.0], + num_tokens=32, + starting_idx=1, + )[0] + scheduler.add_request(request_high) + output = scheduler.schedule() + # KV cache should be full at this point + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 2 + + # Simulate model execution + requests = [request_low, request_high] + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[100] for _ in requests], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Schedule again - this should trigger preemption + # req_low needs 32 tokens = 2 blocks + # req_high needs 33 tokens = 3 blocks + # so doesn't fit in 4 blocks. + output = scheduler.schedule() + + # Should have preempted req_low + assert len(output.scheduled_new_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 \ No newline at end of file diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 849c3f59ae527df75d6312b0e71e434c6169fc7a..e392c2c336e9beaa109aa0dd24f70ac054b6cae9 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -6,7 +6,8 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) @@ -139,15 +140,20 @@ def create_requests( prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_position) - mm_hashes = ["hash"] * len(mm_position) - else: - mm_position = None - mm_kwargs = None - mm_hashes = None + for j, position in enumerate(mm_position): + # Dummy hash for each mm item should be unique + # since encoder cache tracks entries by hash + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) + prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * num_tokens) request = Request( @@ -155,9 +161,7 @@ def create_requests( prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=mm_kwargs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=mm_hashes, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, block_hasher=block_hasher, ) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index d72e50e5196b84fade181c19852f95e6582b5556..6bc9b2b1d82d24fc4bf54c3870c309e98b4eb403 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import pytest import torch @@ -10,12 +9,6 @@ import torch from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n_mm import ( - Gemma3nForConditionalGeneration) -from vllm.model_executor.models.registry import ModelRegistry -from vllm.model_executor.models.utils import extract_layer_index -from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test SEED = 42 -class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward(input_ids, positions, - intermediate_tensors, inputs_embeds, - **kwargs) - attn_metadata = get_forward_context().attn_metadata - # attn_metadata is None during dummy runs - if (attn_metadata is not None - and self.language_model.cache_config.kv_sharing_fast_prefill): - assert isinstance(attn_metadata, dict) # true in V1 - # Gemma3n-E2B has 30 layers, with last 20 layers being - # cross-decoder layers. Check attention metadata is correct - for layer_name, metadata in attn_metadata.items(): - layer_idx = extract_layer_index(layer_name) - if layer_idx >= 20: - assert hasattr(metadata, 'logits_indices_padded') - assert hasattr(metadata, 'num_logits_indices') - else: - assert not hasattr(metadata, 'logits_indices_padded') - assert not hasattr(metadata, 'num_logits_indices') - - # Last layer will be a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.language_model.model.layers[-1].self_attn.attn.layer_name] - logits_indices_padded = (layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] - - return hidden_states - - @pytest.fixture def test_prompts(): """ @@ -119,13 +64,12 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", - TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..f013425cb59dfe6635b6fecdd8768ea4e483e853 --- /dev/null +++ b/tests/v1/e2e/test_min_tokens.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Comprehensive end-to-end tests for `min_tokens` in the V1 engine. + +Addresses #21950: verify and add CI coverage. + +Covers: +1) Basic functionality +2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014) +3) EOS behavior with `min_tokens` (potential logits-processor bug) +4) Edge cases (min_tokens == max_tokens, min_tokens == 0) +5) Multiple stop conditions +""" + +import os +from typing import Optional, Union + +import pytest + +from vllm import LLM, SamplingParams +from vllm.outputs import RequestOutput + +# Test configuration +TEST_MODEL = "facebook/opt-125m" # Small model for fast CI execution +GREEDY = 0.0 # Deterministic generation for consistent testing + + +class MinTokensTestCase: + """Data class for min_tokens test scenarios""" + + def __init__( + self, + name: str, + min_tokens: int, + max_tokens: int, + stop: Optional[Union[str, list[str]]] = None, + expected_min_len: Optional[int] = None, + expected_exact_len: Optional[int] = None, + ): + self.name = name + self.min_tokens = min_tokens + self.max_tokens = max_tokens + self.stop = stop + self.expected_min_len = expected_min_len or min_tokens + self.expected_exact_len = expected_exact_len + + def __str__(self): + return (f"{self.name}: min={self.min_tokens}, " + f"max={self.max_tokens}, stop={self.stop}") + + +# Test scenarios covering all critical cases +MIN_TOKENS_TEST_CASES = [ + # === BASIC FUNCTIONALITY (should work) === + MinTokensTestCase(name="basic_min_tokens_no_stop", + min_tokens=8, + max_tokens=20, + stop=None, + expected_min_len=8), + MinTokensTestCase(name="min_tokens_zero", + min_tokens=0, + max_tokens=10, + stop=None, + expected_min_len=0), + MinTokensTestCase(name="min_equals_max_no_stop", + min_tokens=15, + max_tokens=15, + stop=None, + expected_exact_len=15), + + # === STOP STRINGS WITH MIN_TOKENS === + # These tests expose the detokenizer bug where stop strings + # bypass min_tokens + # Using mathematically guaranteed approach with wide stop nets + pytest.param( + MinTokensTestCase( + name="min_tokens_with_comprehensive_stops", + min_tokens=5, + max_tokens=20, + stop=[ + "a", + "e", + "i", + "o", + "u", + "t", + "n", + "s", + "r", + "l", + " ", + ], + expected_min_len=5, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_comprehensive_stops", + ), + pytest.param( + MinTokensTestCase( + name="min_tokens_with_simple_char_stop", + min_tokens=3, + max_tokens=15, + stop=["e", "a", " "], + expected_min_len=3, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_simple_char_stop", + ), + + # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) === + # These test the MinTokensLogitsProcessor handling of EOS tokens + pytest.param( + MinTokensTestCase( + name="min_equals_max_eos_only", + min_tokens=20, + max_tokens=20, + stop=None, # Relies on default EOS token behavior + expected_exact_len=20, + ), + marks=pytest.mark.xfail( + reason= + ("Potential logits-processor bug: EOS tokens may bypass min_tokens" + ), + strict=False, + ), + id="min_equals_max_eos_only", + ), + + # === EDGE CASES === + MinTokensTestCase(name="large_min_tokens", + min_tokens=50, + max_tokens=60, + stop=None, + expected_min_len=50), + MinTokensTestCase( + name="min_tokens_with_empty_stop_list", + min_tokens=5, + max_tokens=15, + stop=[], # Empty stop list + expected_min_len=5), +] + + +@pytest.fixture(scope="module") +def llm_v1(): + """Create V1 LLM instance for testing""" + # Ensure V1 engine is used + os.environ["VLLM_USE_V1"] = "1" + + llm = LLM( + model=TEST_MODEL, + tensor_parallel_size=1, + max_model_len=1024, # Small context for fast testing + enforce_eager=True, # Avoid graph compilation overhead + ) + return llm + + +def get_token_count(output: RequestOutput) -> int: + """Extract token count from LLM output""" + if not output.outputs: + return 0 + return len(output.outputs[0].token_ids) + + +def assert_min_tokens_satisfied(output: RequestOutput, + test_case: MinTokensTestCase) -> None: + """Assert that min_tokens requirement is satisfied""" + token_count = get_token_count(output) + stop_reason = (output.outputs[0].stop_reason + if output.outputs else "no output") + + if test_case.expected_exact_len is not None: + # Exact length requirement + assert token_count == test_case.expected_exact_len, ( + f"Expected exactly {test_case.expected_exact_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + else: + # Minimum length requirement + assert token_count >= (test_case.expected_min_len or 0), ( + f"Expected at least {test_case.expected_min_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + + +@pytest.mark.parametrize( + "test_case", + MIN_TOKENS_TEST_CASES, + ids=lambda tc: tc.name, +) +def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): + """ + Comprehensive test for min_tokens functionality in V1 engine. + + This test covers all critical scenarios for min_tokens: + - Basic functionality (should work) + - Stop strings with min_tokens (known bug) + - EOS tokens with min_tokens (potential bug) + - Edge cases + + Args: + llm_v1: V1 LLM instance + test_case: Test scenario parameters + """ + # Known failing cases are handled via param-level xfail marks above. + + # Create sampling parameters + sampling_params = SamplingParams( + min_tokens=test_case.min_tokens, + max_tokens=test_case.max_tokens, + stop=test_case.stop, + temperature=GREEDY, + include_stop_str_in_output=True # Include stop strings for debugging + ) + + # Use simple prompt. Comprehensive stop lists should catch any generation + prompt = "Hello" + + # Generate output + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1, "Expected exactly one output" + output = outputs[0] + + # Debug information + token_count = get_token_count(output) + generated_text = output.outputs[0].text if output.outputs else "" + stop_reason = output.outputs[0].stop_reason if output.outputs else "unknown" + + print(f"\nTest: {test_case.name}") + print(f"Generated {token_count} tokens") + print(f"Stop reason: {stop_reason}") + print(f"Generated text: {repr(generated_text)}") + print(f"Expected min: {test_case.expected_min_len}") + if test_case.expected_exact_len: + print(f"Expected exact: {test_case.expected_exact_len}") + + # Validate min_tokens requirement + assert_min_tokens_satisfied(output, test_case) + + +def test_min_tokens_basic_functionality(llm_v1: LLM): + """ + Test basic min_tokens functionality without stop conditions. + + This is a baseline test that should always pass and validates + that min_tokens works correctly in the simple case. + """ + sampling_params = SamplingParams(min_tokens=10, + max_tokens=20, + temperature=GREEDY) + + prompt = "Once upon a time" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + + assert token_count >= 10, f"Expected at least 10 tokens, got {token_count}" + assert token_count <= 20, f"Expected at most 20 tokens, got {token_count}" + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_bug(llm_v1: LLM): + """ + Test the specific bug where stop strings bypass min_tokens. + + This test specifically reproduces the bug Calvin is fixing in PR #22014. + It should fail until that fix is merged. + + Strategy: Use guaranteed stop characters that will appear + in any generated text. + """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=15, + max_tokens=50, + # Common letter; likely appears early + stop=["e"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate text containing "e" + prompt = "The quick brown fox" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + + # Debug info to understand what happened + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Contains 'e': {'e' in generated_text}") + + # This assertion should fail due to the bug - if stop string is found early, + # the model should still continue generating until min_tokens is reached + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "no output") + assert token_count >= 15, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=15. " + f"Reason: {stop_reason}. " + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): + """ + Guaranteed test for stop strings bypassing min_tokens bug. + + Strategy: Use very low temperature and multiple common stop strings + to virtually guarantee early detection, combined with long min_tokens + to ensure the bug is exposed regardless of model behavior. + """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=50, # Set high min_tokens to ensure bug detection + max_tokens=200, + # Use multiple very common patterns - at least one will appear + stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate some text + prompt = "The cat" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "unknown") + + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Stop reason: {stop_reason}") + + # With the bug, this will fail because ANY of the common characters + # will trigger early termination before min_tokens=50 is reached + # It's virtually impossible to generate 50 tokens without hitting + # at least one of: e, a, i, o, u, space, t, n, s, r + finish_reason = (outputs[0].outputs[0].finish_reason + if outputs[0].outputs else "unknown") + + print(f"Finish reason: {finish_reason}") + + if finish_reason == "stop": + assert token_count >= 50, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=50. " + f"Reason: {finish_reason}. " + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=( + "Potential logits-processor bug: EOS tokens may bypass min_tokens"), + strict=False, +) +def test_min_tokens_eos_behavior(llm_v1: LLM): + """ + Verify EOS handling with and without min_tokens. + + - Without min_tokens: expect early EOS -> finish_reason == "stop", + stop_reason is None, and generated tokens < max_tokens (25). + - With min_tokens: EOS should be blocked until min_tokens is reached + (finish_reason == "length"); verify that eos_token_id does not appear + in generated token_ids. + """ + # tokenizer + eos id + tokenizer = llm_v1.get_tokenizer() + eos_token_id = tokenizer.eos_token_id + + prompt = "Give a file extension." + max_toks = 32 + + # Case 1: WITHOUT min_tokens + sp_no_min = SamplingParams( + max_tokens=max_toks, + temperature=GREEDY, + ) + out_no_min = llm_v1.generate([prompt], sp_no_min) + assert len(out_no_min) == 1 + choice_no_min = out_no_min[0].outputs[0] + + ids_no_min = choice_no_min.token_ids or [] + finish_no_min = choice_no_min.finish_reason + stop_no_min = choice_no_min.stop_reason + + print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min, + " stop_reason=", stop_no_min) + + assert finish_no_min == "stop", ( + f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}" + ) + assert stop_no_min is None, ( + "For EOS-based stop (no user stop strings), stop_reason should be None." + ) + assert len(ids_no_min) < max_toks, ( + f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}") + + # Case 2: WITH min_tokens + sp_with_min = SamplingParams( + min_tokens=max_toks, + max_tokens=max_toks, + temperature=GREEDY, + ) + out_with_min = llm_v1.generate([prompt], sp_with_min) + assert len(out_with_min) == 1 + choice_with_min = out_with_min[0].outputs[0] + + ids_with_min = choice_with_min.token_ids or [] + finish_with_min = choice_with_min.finish_reason + stop_with_min = choice_with_min.stop_reason + + print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min, + " stop_reason=", stop_with_min) + + # Exact length reached; EOS should have been blocked + assert len(ids_with_min) == max_toks, ( + f"Expected exactly {max_toks} tokens with min_tokens; " + f"got {len(ids_with_min)}") + assert finish_with_min == "length", ( + f"Expected finish_reason 'length'; got {finish_with_min}") + assert eos_token_id not in ids_with_min, ( + "EOS token id should not appear when min_tokens prevents early EOS.") + + +def test_min_tokens_validation(): + """ + Test that SamplingParams correctly validates min_tokens parameters. + + This tests the parameter validation logic in SamplingParams. + """ + # Valid cases + SamplingParams(min_tokens=0, max_tokens=10) + SamplingParams(min_tokens=5, max_tokens=10) + SamplingParams(min_tokens=10, max_tokens=10) + + # Invalid cases + with pytest.raises( + ValueError, + match="min_tokens must be greater than or equal to 0", + ): + SamplingParams(min_tokens=-1, max_tokens=10) + + with pytest.raises( + ValueError, + match="min_tokens must be less than or equal to max_tokens", + ): + SamplingParams(min_tokens=15, max_tokens=10) + + +if __name__ == "__main__": + """ + Run tests locally for development. + + Usage: + cd vllm/ + VLLM_USE_V1=1 python -m pytest tests/v1/e2e/test_min_tokens.py -v + """ + pytest.main([__file__, "-v"]) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 0f3dd00783b7d19b5e30f54f32f73913fde06739..0a9ed2f16106449e45b78d7a38622e6dc445c723 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -149,6 +149,8 @@ def test_ngram_correctness( os.path.join(models_path_prefix, "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"), 4), True, marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 @@ -156,7 +158,8 @@ def test_ngram_correctness( "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm" + "llama4_eagle_mm", + "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -182,6 +185,7 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) if (attn_backend == "TRITON_ATTN_VLLM_V1" diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 6a9749dcf3cd3b7c5598e1f17c68de229453867f..f609372e81e6114bcd8b3197bf2b6b65ca1bd4b5 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -37,9 +37,7 @@ def make_request() -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=SamplingParams(), pooling_params=None, eos_token_id=None, @@ -310,17 +308,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): # Schedule Batch 1: (10, req0) assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 1 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 10 # num_computed_tokens should have been updated immediately. assert engine_core.scheduler.requests[ req0.request_id].num_computed_tokens == 10 # Schedule Batch 2: (2, req0), (8, req1) - assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert engine_core.step_with_batch_queue()[0] == {} + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 2 assert scheduler_output.num_scheduled_tokens["1"] == 8 # num_computed_tokens should have been updated immediately. @@ -329,42 +327,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Batch queue is full. Finish Batch 1. - engine_core.step_with_batch_queue() - - # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled + # Finish Batch 1 and schedule Batch 3: (4, req1). + # Note that req0 cannot be scheduled # because it is in the decoding stage now. engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["1"] == 4 - # Batch queue is full. Finish Batch 2. Get first token of req0. + # Finish Batch 2. Get first token of req0. + # Schedule Batch 4: (1, req0). output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 - - # Schedule Batch 4: (1, req0). - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 1 - # Batch queue is full. Finish Batch 3. Get first token of req1. + # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1). output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 - - # Schedule Batch 5: (1, req1). - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["1"] == 1 # Loop until req0 is finished. - step = 0 req_id = 0 expected_num_tokens = [ engine_core.scheduler.requests["0"].num_tokens + 1, @@ -372,19 +360,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ] while engine_core.scheduler.get_num_unfinished_requests() == 2: output = engine_core.step_with_batch_queue()[0] - if step % 2 == 0: - # Even steps consumes an output. - assert output is not None - assert len(output[0].outputs) == 1 - if req_id in engine_core.scheduler.requests: - assert engine_core.scheduler.requests[ - req_id].num_tokens == expected_num_tokens[req_id] - expected_num_tokens[req_id] += 1 - req_id = (req_id + 1) % 2 - else: - # Odd steps schedules a new batch. - assert output is None - step += 1 + # Every step consumes an output. + assert output is not None + assert len(output[0].outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert engine_core.scheduler.requests[ + req_id].num_tokens == expected_num_tokens[req_id] + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 @multi_gpu_test(num_gpus=2) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 1bac426ab9984d3b8b536f3f8589bcecf6ed7c63..cd74aa1d5cbc461dc43005384c1fc5148b107cc2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -54,9 +54,7 @@ def make_request( return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=prompt_tokens_ids, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=params, pooling_params=None, eos_token_id=None, diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index 270d8fae83c5c65cf6e51638744ef0fe75374bee..49d80247bcbc56d681db6c90cd9b06c85fbaf3f4 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -28,16 +28,14 @@ def test_fast_inc_detok_invalid_utf8_err_case(): prompt_token_ids = [107, 4606, 236787, 107] params = SamplingParams(skip_special_tokens=True) request = EngineCoreRequest( - "test", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, + request_id="test", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None, ) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 9caab083ffecd6e7e9f3776390d4d3efcffde974..b96bba2e50e0f4845cab31dc97c47d9dea854cea 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -54,11 +54,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -403,11 +401,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -568,11 +564,9 @@ def test_stop_token(include_stop_str_in_output: bool, request = EngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=eos_token_id, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -667,11 +661,9 @@ def test_stop_string(include_stop_str_in_output: bool, EngineCoreRequest( request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -783,11 +775,9 @@ def test_iteration_stats(dummy_test_vectors): EngineCoreRequest( request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py new file mode 100644 index 0000000000000000000000000000000000000000..970a59eca8ece42ad9983dd9a378daf6cdc29d3f --- /dev/null +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig +from vllm.platforms.interface import UnspecifiedPlatform +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import processor as processor_mod +from vllm.v1.engine.processor import Processor + +cherry_pil_image = ImageAsset("cherry_blossom").pil_image +stop_pil_image = ImageAsset("stop_sign").pil_image +baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays + + +# Mock processor for testing +def _mk_processor(monkeypatch, + *, + mm_cache_gb: float = 4.0, + enable_prefix_caching: bool = True) -> Processor: + """ + Create a Processor instance with minimal configuration suitable for unit + tests without accessing external resources. + """ + monkeypatch.setattr(ModelConfig, + "try_get_generation_config", + lambda self: {}, + raising=True) + monkeypatch.setattr(ModelConfig, + "__post_init__", + lambda self: None, + raising=True) + monkeypatch.setattr(UnspecifiedPlatform, + "is_async_output_supported", + classmethod(lambda cls, enforce_eager: True), + raising=True) + monkeypatch.setattr( + ModelConfig, + "verify_async_output_proc", + lambda self, parallel_config, speculative_config, device_config: None, + raising=True) + monkeypatch.setattr(ModelConfig, + "verify_with_parallel_config", + lambda self, parallel_config: None, + raising=True) + monkeypatch.setattr(processor_mod, + "processor_cache_from_config", + lambda vllm_config, mm_registry: None, + raising=True) + + monkeypatch.setattr(VllmConfig, + "__post_init__", + lambda self: None, + raising=True) + + model_config = ModelConfig( + skip_tokenizer_init=True, + max_model_len=128, + mm_processor_cache_gb=mm_cache_gb, + generation_config="vllm", + tokenizer="dummy", + ) + + # Minimal multimodal_config to satisfy references in + # Processor.process_inputs. + class _MockMMConfig: + + def __init__(self, gb: float): + self.mm_processor_cache_gb = gb + + model_config.multimodal_config = _MockMMConfig( + mm_cache_gb) # type: ignore[attr-defined] + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), + device_config=DeviceConfig(device="cpu"), + ) + + # Pass tokenizer=None; InputPreprocessor handles None when + # skip_tokenizer_init is True. + return Processor(vllm_config, tokenizer=None) # type: ignore[arg-type] + + +def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image>\nDescribe\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image] + }, + # Mismatch: 2 items but only 1 uuid provided + "multi_modal_uuids": { + "image": ["hash_cherry"] + }, + } + + with pytest.raises(ValueError, match="must have same length as data"): + processor.process_inputs( + request_id="req-1", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +def test_multi_modal_uuids_missing_modality_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image><video>\nDescribe\nASSISTANT:", + # Two modalities provided in data + "multi_modal_data": { + "image": [cherry_pil_image], + "video": [baby_reading_np_ndarrays] + }, + # Only image uuids provided; video missing should raise + "multi_modal_uuids": { + "image": ["hash_cherry"] + }, + } + + with pytest.raises(ValueError, + match="must be provided if multi_modal_data"): + processor.process_inputs( + request_id="req-2", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +@pytest.mark.parametrize( + "mm_cache_gb, enable_prefix_caching", + [ + (4.0, True), # default behavior + (4.0, False), # prefix caching disabled + (0.0, True), # processor cache disabled + ], +) +def test_multi_modal_uuids_accepts_none_and_passes_through( + monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool): + processor = _mk_processor(monkeypatch, + mm_cache_gb=mm_cache_gb, + enable_prefix_caching=enable_prefix_caching) + + # Capture the overrides passed to InputPreprocessor.preprocess + captured: dict[str, object] = {} + + def fake_preprocess(prompt, + *, + tokenization_kwargs=None, + lora_request=None, + mm_hash_overrides=None): + captured["mm_hash_overrides"] = mm_hash_overrides + # Minimal processed inputs for decoder-only flow + return {"type": "token", "prompt_token_ids": [1]} + + # Monkeypatch only the bound preprocess method on this instance + monkeypatch.setattr(processor.input_preprocessor, + "preprocess", + fake_preprocess, + raising=True) + + # Use a consistent two-image scenario across all configurations + mm_uuids = {"image": [None, "hash_stop"], "video": None} + prompt = { + "prompt": "USER: <image><image>\nTwo images\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id="req-3", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + assert captured["mm_hash_overrides"] == mm_uuids + + +def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): + # When both processor cache is 0 and prefix caching disabled, the + # processor builds overrides from request id instead of using user UUIDs. + processor = _mk_processor(monkeypatch, + mm_cache_gb=0.0, + enable_prefix_caching=False) + + captured: dict[str, object] = {} + + def fake_preprocess(prompt, + *, + tokenization_kwargs=None, + lora_request=None, + mm_hash_overrides=None): + captured["mm_hash_overrides"] = mm_hash_overrides + return {"type": "token", "prompt_token_ids": [1]} + + monkeypatch.setattr(processor.input_preprocessor, + "preprocess", + fake_preprocess, + raising=True) + + request_id = "req-42" + mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"} + prompt = { + "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id=request_id, + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + # Expect request-id-based overrides are passed through + assert captured["mm_hash_overrides"] == { + "image": [f"{request_id}-image-0", f"{request_id}-image-1"], + "video": [f"{request_id}-video-0"], + } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 6daad4ff803b8a296d7b3b5c090ee76f8409748a..d433a18652368b7c4024e3a6741012dfe907a560 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -12,9 +12,11 @@ import os import jsonschema import pytest import regex as re +import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -41,8 +43,11 @@ EAGLE_SPEC_CONFIG = { PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "xgrammar", "auto", None), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "guidance", "auto", None), + (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "lm-format-enforcer", "auto", + None), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "xgrammar", "mistral", None), - (os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct"), "xgrammar", "auto", None), + (os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct", "xgrammar"), "auto", None), + (os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct"), "lm-format-enforcer", "auto", None), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "outlines", "auto", None), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "outlines", "mistral", None), (os.path.join(models_path_prefix, "mistralai/Ministral-8B-Instruct-2410"), "outlines", "auto", @@ -129,13 +134,15 @@ def test_structured_output( temperature=1.0, max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) - outputs = llm.generate(prompts=[ - (f"Give an example JSON for an employee profile that fits this " - f"schema. Make the response as short as possible. Schema: " - f"{sample_json_schema}") - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + + prompt = ("Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + outputs = llm.generate( + [prompt] * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -146,7 +153,8 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - assert "\n" not in generated_text + if guided_decoding_backend != 'lm-format-enforcer': + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) @@ -193,20 +201,24 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): + + prompt = (f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.") llm.generate( - prompts=[(f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.")] * 2, + [prompt] * 2, sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) else: - outputs = llm.generate(prompts=( - "Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + prompt = (f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.") + outputs = llm.generate( + prompt, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -219,7 +231,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL statement using EBNF grammar # @@ -229,10 +241,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -263,10 +274,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -303,7 +313,6 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts= ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short " "as possible."), @@ -318,11 +327,11 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") outputs = llm.generate( - prompts=[ - (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") - ] * 2, + [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) @@ -345,11 +354,13 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + outputs = llm.generate( - prompts=("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ("The best language for type-safe systems programming is " + "(Make the response as short as possible.) "), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -369,12 +380,14 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate(prompts=( - "Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + + outputs = llm.generate( + ("Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -413,10 +426,11 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts=("Generate a description of a frog using 50 characters. " - "Make the response as short as possible."), + ("Generate a description of a frog using 50 characters. " + "Make the response as short as possible."), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None @@ -431,7 +445,7 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # @@ -500,7 +514,7 @@ Make the response as short as possible. """ # Change this once other backends support structural_tag - outputs = llm.generate(prompts=prompt, + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -641,15 +655,13 @@ def test_structured_output_auto_mode( f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts=prompts, + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. outputs.extend( - llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) assert outputs is not None for output in outputs: @@ -707,7 +719,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): max_tokens=256, guided_decoding=guided_params) - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None generated_text = outputs[0].outputs[0].text assert generated_text is not None @@ -722,4 +734,84 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a3" in generated assert "a4" not in generated assert "a5" not in generated - assert "a6" not in generated \ No newline at end of file + assert "a6" not in generated + + +@pytest.mark.parametrize("guided_decoding_backend", + ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_guided_requests( + monkeypatch: pytest.MonkeyPatch, + sample_json_schema: dict[str, Any], + guided_decoding_backend: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) + + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + enforce_eager=enforce_eager, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=(guided_decoding_backend + in {"xgrammar", "guidance"}), + ) + + guided_prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + + non_guided_prompt = "The diameter of the Earth in kilometers is " + + prompts = [guided_prompt, non_guided_prompt] + sampling_params = [ + SamplingParams( + temperature=1.0, + max_tokens=400, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + # No max tokens, temp=0 to assert on contents + SamplingParams( + seed=42, + temperature=0, + top_p=1.0, + ), + ] + + outputs = llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + # Free memory as soon as possible as failed assertions + # will short circuit and not free up memory + del llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + for index, output in enumerate(outputs): + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") + + if index == 0: + # First prompt is guided, expect valid JSON + assert "\n" not in generated_text + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_json_schema) + else: + # Second prompt is not guided, expect valid output + # Cannot assert on exact output, but we can expect it to be factual + assert "12,742" in generated_text + + # non-guided requests should not return a valid JSON here + with pytest.raises(ValueError): + output_json = json.loads(generated_text) diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 18c35152e7b20c70903fd7fa5760d6aa3c54a996..7a0baa5767cbad92a05a889601da955f28f5cd77 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI): ], ) print(response) assert response.status == "completed" + + +@pytest.mark.asyncio +async def test_logprobs(client: openai.AsyncOpenAI): + response = await client.responses.create( + include=["message.output_text.logprobs"], + input="What is 13 * 24?", + top_logprobs=5, + ) + print(response) + outputs = response.output + assert outputs[-1].content[-1].logprobs + assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 diff --git a/tests/prefix_caching/__init__.py b/tests/v1/executor/__init__.py similarity index 100% rename from tests/prefix_caching/__init__.py rename to tests/v1/executor/__init__.py diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd5155c1481d55dc1a21dee5caaab0bb5843f19 --- /dev/null +++ b/tests/v1/executor/test_executor.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import os +from typing import Any, Callable, Optional, Union + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.executor.multiproc_executor import MultiprocExecutor + + +class Mock: + ... + + +class CustomMultiprocExecutor(MultiprocExecutor): + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None) -> list[Any]: + # Drop marker to show that this was ran + with open(".marker", "w"): + ... + return super().collective_rpc(method, timeout, args, kwargs) + + +CustomMultiprocExecutorAsync = CustomMultiprocExecutor +MODEL = "Qwen/Qwen3-0.6B" + + +def test_custom_executor_type_checking(): + with pytest.raises(ValueError): + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock, + ) + LLMEngine.from_engine_args(engine_args) + with pytest.raises(ValueError): + engine_args = AsyncEngineArgs(model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock) + AsyncLLM.from_engine_args(engine_args) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutor, + "tests.v1.executor.test_executor.CustomMultiprocExecutor" +]) +def test_custom_executor(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + engine.add_request("0", "foo", sampling_params) + engine.step() + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutorAsync, + "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync" +]) +def test_custom_executor_async(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = AsyncEngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = AsyncLLM.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + async def t(): + stream = engine.generate(request_id="0", + prompt="foo", + sampling_params=sampling_params) + async for x in stream: + ... + + asyncio.run(t()) + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e6859ea7382771f26265b3ad579952d19585d0e7..040b44dc5d2ca41794e5dc77b75e22a1550418ad 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -14,6 +14,7 @@ from unittest.mock import patch import pytest import ray +import torch from vllm import LLM from vllm.config import KVTransferConfig @@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorWorker) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from .utils import create_request, create_scheduler, create_vllm_config @@ -98,7 +100,6 @@ class FakeNixlWrapper: def set_cycles_before_xfer_done(self, cycles: int): """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles @contextlib.contextmanager @@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params) # Request-0 times out and is cleared! assert '0' not in req_to_blocks + + +def test_register_kv_caches(dist_init): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config() + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, + block_size=16, + num_kv_heads=4, + head_size=64) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size( + ) * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + ] + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == 4 + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, \ + f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ + f"got {size}" + assert base_addr == expected_base_addrs[i], \ + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + f"got {base_addr}" + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = 8 + assert len(blocks_data) == expected_blocks_count, \ + f"Expected {expected_blocks_count} blocks, " \ + f"got {len(blocks_data)}" + + expected_block_len = expected_tensor_size // 2 + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, \ + f"Block entry {i}: Expected block len {expected_block_len}, " \ + f"got {block_len}" diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 8c5d132c00ae4b2ba8842fb2c8bd3f9263db68aa..3f068d5e8c7eb3bcc92ec656efe1d8686de05a1d 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -162,9 +162,7 @@ def create_request(request_id: int, prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, + mm_features=None, eos_token_id=EOS_TOKEN_ID, block_hasher=get_request_block_hasher(block_size, hash_fn), ) @@ -200,7 +198,6 @@ def create_model_runner_output( req_ids=req_ids, req_id_to_index=req_id_to_index, sampled_token_ids=sampled_token_ids, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=None, diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index c0bfc1a18feca201f204cac6d81cf3d2c2ada00b..c36f1bd021c705ab183c341cfa7216e077723bed 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -8,10 +8,9 @@ from typing import Optional import torch from vllm.config import VllmConfig -from vllm.sampling_params import SamplingParams from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, - LogitsProcessor, - MoveDirectionality) + LogitsProcessor) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates MODEL_NAME = "facebook/opt-125m" POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" @@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor): def __init__(self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool): - self.req_info: dict[int, SamplingParams] = {} + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests. - for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and (target_token := - params.extra_args.get("target_token")): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + process_dict_updates( + self.req_info, + batch_update, + lambda params, _, __: params.extra_args and + (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 15f7c81387b3b52bdda451dc4893be8d971fb55e..c7c9a69ce86a12020b3f46765b9bb21ac773a93b 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -458,9 +458,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): assert len(logprob) == vocab_size -@pytest.mark.parametrize( - "logprobs_mode", - ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. @@ -489,12 +487,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, for logprobs in output.logprobs: for token_id in logprobs: logprob = logprobs[token_id] - if "logprobs" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, + LogprobsMode.PROCESSED_LOGPROBS): assert logprob.logprob <= 0 if logprob.logprob > 0: positive_values = positive_values + 1 total_token_with_logprobs = total_token_with_logprobs + 1 assert total_token_with_logprobs >= len(results[0].outputs) - if "logits" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGITS, + LogprobsMode.PROCESSED_LOGITS): assert positive_values > 0 del llm diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 456ce712d36e4608ee12cc687562f03f315abcce..63178174086611156bc0e21f225aa50716c5dc6b 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -50,6 +50,7 @@ def forward_attention( dtype=torch.int32, ) context_lens = seq_lens - query_lens + max_seq_len = int(seq_lens.max()) max_query_len = q_len num_actual_tokens = query_start_loc[-1] @@ -81,6 +82,7 @@ def forward_attention( num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table, slot_mapping=slot_mapping, ) diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index 15462d698b757fcd1359f05709e9c056509eb52c..de6ebc0b09931f6bd963c04071f241f51d461a4e 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -78,9 +78,10 @@ async def generate( ], ) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) +@pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.asyncio -async def test_load(output_kind: RequestOutputKind, - data_parallel_backend: str): +async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str, + async_scheduling: bool): stats_loggers = {} @@ -108,6 +109,7 @@ async def test_load(output_kind: RequestOutputKind, prompt = "This is a test of data parallel" engine_args.data_parallel_backend = data_parallel_backend + engine_args.async_scheduling = async_scheduling engine = AsyncLLM.from_engine_args(engine_args, stat_loggers=[SimpleStatsLogger]) after.callback(engine.shutdown) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 586276ee08aef18a8dea0fcd41b1a1dc890edb01..118b40d0ef4188e0e914a4ac9f0041f1598a2e8f 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -11,7 +11,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): class MyRequest(msgspec.Struct): - mm: Optional[list[MultiModalKwargs]] + mm: Optional[list[MultiModalKwargsItems]] def test_multimodal_kwargs(): @@ -119,7 +120,7 @@ def test_multimodal_kwargs(): audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) - mm = MultiModalKwargs([audio, video, image]) + mm = MultiModalKwargsItems.from_seq([audio, video, image]) # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest([mm]) @@ -133,19 +134,22 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 14255, +-20 for minor changes - assert 14250 <= total_len <= 14300 - decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + # expected total encoding length, should be 14306, +-20 for minor changes + assert 14275 <= total_len <= 14325 + decoded = decoder.decode(encoded).mm[0] + assert isinstance(decoded, MultiModalKwargsItems) # check all modalities were recovered and do some basic sanity checks - assert len(decoded.modalities) == 3 - images = decoded.get_items("image") + assert len(decoded) == 3 + images = decoded["image"] assert len(images) == 1 assert len(images[0].items()) == 2 assert list(images[0].keys()) == ["i0", "i1"] # check the tensor contents and layout in the main dict - assert all(nested_equal(mm[k], decoded[k]) for k in mm) + mm_data = mm.get_data() + decoded_data = decoded.get_data() + assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data) def nested_equal(a: NestedTensors, b: NestedTensors): diff --git a/tests/v1/tpu/worker/untest_tpu_model_runner.py b/tests/v1/tpu/worker/untest_tpu_model_runner.py index 5a05781a03f2a4bef94bf073dc780667967e4282..941aa0a77692c4a870a9f7d6344e11578268b9a2 100644 --- a/tests/v1/tpu/worker/untest_tpu_model_runner.py +++ b/tests/v1/tpu/worker/untest_tpu_model_runner.py @@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index d7b4746562beb9d7120519922b70efaeb0bb59ef..70318590782645977cd3621b5d15c4b1d1c99265 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int): pooling_params=None, mm_kwargs=[], mm_positions=[], + mm_hashes=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 78756f1998cdb84f8c9856d397e47b560fe916ed..fe4fb73a15d203c4197827cf5e4b65a8bd6a40f1 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -143,7 +143,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -209,7 +209,7 @@ def test_update_states_request_finished(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -241,7 +241,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -268,7 +268,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -349,7 +349,7 @@ def test_update_states_no_changes(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -386,7 +386,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -682,6 +682,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) + kv_cache_config_after_init = runner.kv_cache_config layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0] @@ -689,10 +690,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert id(layer_1_kv) == id(layer_0_kv) # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + assert len(kv_cache_config_after_init.kv_cache_groups) == 1 + assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1b797074096ed928c61ea008f3f1d853f1e087ae..cc18c9ff1f096d94f006ccb5fa9f768bcc4fff00 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main -marlin, nm-testing/zephyr-beta-7b-marlin-g128, main -marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main -qqq, HandH1998/QQQ-Llama-3-8b-g128, main -qqq, HandH1998/QQQ-Llama-3-8b, main hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main None, mgleize/fairseq2-dummy-Llama-3.2-1B, main \ No newline at end of file diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2031f41fab87d061a9a67b2511a948bdfc3fa63c..0f28ef2ba857b19933b1f50dee85dc275a00f806 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -from vllm.worker.pooling_model_runner import ( - ModelInputForGPUWithPoolingMetadata) class MockAttentionBackend(AttentionBackend): @@ -114,54 +111,3 @@ def test_model_runner_input(): assert (received_model_input.sampling_metadata.selected_token_indices == sampling_metadata.selected_token_indices) assert received_model_input.sampling_metadata.seq_groups is None - - -def test_embedding_model_runner_input(): - pooling_metadata = PoolingMetadata( - seq_groups=[[0]], - seq_data={}, - prompt_lens=[1], - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - model_input = ModelInputForGPUWithPoolingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - pooling_metadata=pooling_metadata, - attn_metadata=attn_metadata) - - assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata) - - # Test round trip serialization. - tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = ( - ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - # Check that received copy has correct values. - assert isinstance(received_model_input, - ModelInputForGPUWithPoolingMetadata) - assert received_model_input.input_tokens is not None - assert ( - received_model_input.input_tokens == model_input.input_tokens).all() - assert received_model_input.input_positions is not None - assert (received_model_input.input_positions == model_input.input_positions - ).all() - assert received_model_input.multi_modal_kwargs is None - assert (received_model_input.multi_modal_kwargs == - model_input.multi_modal_kwargs) - assert received_model_input.lora_requests is None - assert received_model_input.lora_requests == model_input.lora_requests - assert received_model_input.lora_mapping is None - assert received_model_input.lora_mapping == model_input.lora_mapping - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_model_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # Pooling metadata is not broadcast. - assert received_model_input.pooling_metadata is None diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index 444e2bf53f995de9e9796d783faff73debd0e4cc..ad0ae45d1d465b961c9adcc66ac8c3605709a3c5 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -37,7 +37,7 @@ ALLOWED_FILES = set([ 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index e163c83e8b51308038348e88ce47a2714ea9fdea..59bfe69dc0dd6b453295b3b098321a07b2dff457 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -77,6 +77,7 @@ clone_repo() { local repo_url=$1 local dir_name=$2 local key_file=$3 + local commit_hash=$4 if [ -d "$dir_name" ]; then # Check if directory has uncommitted changes (dirty) @@ -87,17 +88,27 @@ clone_repo() { echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning" rm -rf "$dir_name" git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi else echo "$dir_name directory exists and appears complete; manually update if needed" fi else git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi fi } # build and install pplx, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" +clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf" cd pplx-kernels # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # PIP_NO_BUILD_ISOLATION=0 disables build isolation @@ -106,7 +117,7 @@ popd # build and install deepep, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" +clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf" cd DeepEP export NVSHMEM_DIR=$WORKSPACE/nvshmem_install PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh new file mode 100755 index 0000000000000000000000000000000000000000..33849581d2c0e0d29fbb5f5f4d2a31d261174a51 --- /dev/null +++ b/tools/install_deepgemm.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Script to install DeepGEMM from source +# This script can be used both in Docker builds and by users locally + +set -e + +# Default values +DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" +DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --ref) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --ref requires an argument." >&2 + exit 1 + fi + DEEPGEMM_GIT_REF="$2" + shift 2 + ;; + --cuda-version) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --cuda-version requires an argument." >&2 + exit 1 + fi + CUDA_VERSION="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)" + echo " --cuda-version VER CUDA version (auto-detected if not provided)" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +# Auto-detect CUDA version if not provided +if [ -z "$CUDA_VERSION" ]; then + if command -v nvcc >/dev/null 2>&1; then + CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p') + echo "Auto-detected CUDA version: $CUDA_VERSION" + else + echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version" + exit 1 + fi +fi + +# Extract major and minor version numbers +CUDA_MAJOR="${CUDA_VERSION%%.*}" +CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" +CUDA_MINOR="${CUDA_MINOR%%.*}" + +echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)" + +# Check CUDA version requirement +if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then + echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" + exit 0 +fi + +echo "Installing DeepGEMM from source..." +echo "Repository: $DEEPGEMM_GIT_REPO" +echo "Reference: $DEEPGEMM_GIT_REF" + +# Create a temporary directory for the build +INSTALL_DIR=$(mktemp -d) +trap 'rm -rf "$INSTALL_DIR"' EXIT + +# Clone the repository +git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm" + +echo "🏗️ Building DeepGEMM" +pushd "$INSTALL_DIR/deepgemm" + +# Checkout the specific reference +git checkout "$DEEPGEMM_GIT_REF" + +# Build DeepGEMM +# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) +rm -rf build dist +rm -rf *.egg-info +python3 setup.py bdist_wheel + +# Install the wheel +if command -v uv >/dev/null 2>&1; then + echo "Installing DeepGEMM wheel using uv..." + # Use --system in Docker contexts, respect user's environment otherwise + if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then + uv pip install --system dist/*.whl + else + uv pip install dist/*.whl + fi +else + echo "Installing DeepGEMM wheel using pip..." + python3 -m pip install dist/*.whl +fi + +popd + +echo "✅ DeepGEMM installation completed successfully" \ No newline at end of file diff --git a/tools/profiler/nsys_profile_tools/README.md b/tools/profiler/nsys_profile_tools/README.md index 75ae0811cc5434fd7a94475ff28eb63b90deacd5..9577efb68fb4b35934d83ca6ef284866896a7347 100644 --- a/tools/profiler/nsys_profile_tools/README.md +++ b/tools/profiler/nsys_profile_tools/README.md @@ -36,8 +36,7 @@ profiling and analyzing nsys profile output. ## Notes - Make sure you have pandas installed. -- Make sure nsys is installed, and specify the path to the `nsys` command with - `--nsys_cmd` if it is not in your PATH. +- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH. - For more details on available engines and models, see the help string in the script or run: @@ -135,34 +134,31 @@ time which would cause a difference for the overall category. ## Example 3: add new classification for a new model -Suppose there's a new model ABC that is available for engine DEF, and say there -are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels +To create a new engine DEF with model ABC, just add another json file in the same directory as +gputrc2graph.py with the same format as the other json files. The script will automatically pick up all the json files in the same directory as engine/model specifications. + +Then, for this new model, suppose there are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels have names with "*H*" or "*I*" in them, and attn kernels have names with "*J*" -or "*K*" in them, add a new entry like so: - -```python -engine_model = { - 'DEF': { - 'ABC': { - 'layer_anno': { - 'Stage': { - '.*': 'layer', - }, - 'Substage': { - 'H|I': 'gemm', - 'J|K': 'attn', - 'CUDA mem': 'non-gpu-H_D_memops', - '.*': 'misc' - } - } - }, - } - 'vllm': {...} +or "*K*" in them, just add another .json file in the same directory as +gputrc2graph.py with the same format as the other json files, like the following: + +```json +{ + "DEF": { + "ABC": { + "H|I": "gemm", + "J|K": "attn", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} ``` -Basically Substage is a dictionary with a list of key/value pairs, where the -keys are regex's of the kernel names to be classified, and values are the -classification bins which one wishes to compare across engines/models. +Each entry in the dictionary consists of: + +- key: a regex used to classify the kernels +- value: the category to classify the kernels into. The last 2 entries are common for all engine/models, consisting of CUDA memory operations and a 'misc' for anything that's leftover and can't be classified. @@ -173,3 +169,6 @@ like the following: ```bash --infile new.nsys-rep,DEF,ABC,<runtime> ``` + +If the engine_DEF.json file already exists, just add the model as a new node in +the existing engine file, after the other models. diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py index 8921e1f20f3dafd9afb0b2745ed70d99ce059cbf..42dfede9e98705efc46fbc4c8d74a3ffe1c4bb69 100755 --- a/tools/profiler/nsys_profile_tools/gputrc2graph.py +++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -15,132 +15,18 @@ logger = logging.getLogger(__name__) # helper data class for annotating kernels -class EngineModelData: - # engine + model mappings - engine_model = { - 'vllm': { - 'llama': { - 'layer_anno': { - 'Stage': { - '.*': 'layer', - }, - 'Substage': { - 'gemm': 'gemm', - 'fused_moe_kernel|GroupProblemShape|group_gemm_starts': - 'moe_gemm', #llama4 - 'moe|sigmoid': 'moe', #llama4 - 'CatArrayBatched|prepare_inputs': 'prepare_next', - 'flash': 'attn', - 'ncclDevKernel|cross_device_reduce': - 'nccl_and_custom_ar', - '_norm_': 'norm', - 'act_and_mul_': 'silu', - 'rotary_embedding_kernel': 'rope', - 'SoftMax': 'softmax', - 'elementwise': 'elementwise', - 'fp8_quant': 'quantize', - 'reduce_kernel': 'reduce', - 'triton': 'triton_kernel', - 'CUDA mem': 'non-gpu-H_D_memops', - '.*': 'misc' - } - } - }, - 'ds': { - 'layer_anno': { - 'Stage': { - '.*': 'layer', - }, - 'Substage': { - 'block_fp8|gemm_fp8_blockwise': - 'block_fp8_gemm', - 'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal': - 'moe_gemm', - 'gemm|matmul|nvjet': - 'gemm', - 'moe|sigmoid|expert': - 'moe', - '_fwd_|FlashAttn|_mla_|_attn_': - 'attn', - 'CatArrayBatched': - 'prepare_next', - 'ncclDevKernel|cross_device_reduce': - 'nccl_and_custom_ar', - 'Norm|_norm_': - 'norm', - 'sbtopk': - 'topk', - 'act_and_mul_': - 'activation', - 'compute_position_kernel': - 'rope', - 'elementwise': - 'elementwise', - 'fp8_quant|quant_fp8|cvt_fp16_to_fp4': - 'quantize', - 'reduce': - 'reduce', - 'SoftMax': - 'softmax', - 'triton': - 'triton_kernel', - 'CUDA mem': - 'non-gpu-H_D_memops', - '.*': - 'misc' - } - } - }, - 'gpt-oss': { - 'layer_anno': { - 'Stage': { - '.*': 'layer', - }, - 'Substage': { - 'block_fp8|gemm_fp8_blockwise': - 'block_fp8_gemm', - 'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_' - # this section is triton_moe_gemm - '|matmul_ogs_|_topk_forward|_combined_routing' - '|_sum_bitmatrix_rows|_compute_writeback_idx': - 'moe_gemm', - 'gemm|matmul|nvjet': - 'gemm', - 'moe|sigmoid|expert|splitKreduce': - 'moe', - '_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha': - 'attn', - 'CatArrayBatched': - 'prepare_next', - 'ncclDevKernel|cross_device_reduce': - 'nccl_and_custom_ar', - 'Norm|_norm_': - 'norm', - 'sbtopk': - 'topk', - 'act_and_mul_': - 'activation', - 'compute_position_kernel': - 'rope', - 'elementwise': - 'elementwise', - 'fp8_quant|quant_fp8|cvt_fp16_to_fp4|quantize': - 'quantize', - 'reduce': - 'reduce', - 'SoftMax': - 'softmax', - 'triton': - 'triton_kernel', - 'CUDA mem': - 'non-gpu-H_D_memops', - '.*': - 'misc' - } - } - } - }, - } +def load_engine_model(): + """ returns engine_model built from all json files in the current dir """ + import glob + import json + engine_model = {} + + json_files = glob.glob( + os.path.join(os.path.dirname(__file__) or ".", "*.json")) + for fname in json_files: + with open(fname, encoding="utf-8") as f: + engine_model.update(json.load(f)) + return engine_model class GPUTrace2Graph: @@ -148,8 +34,7 @@ class GPUTrace2Graph: Parses output of nsys report, generates csv and bar chart output """ - def __init__(self, nsys_cmd): - self.nsys_cmd = nsys_cmd + def __init__(self): import pandas as pd # avoid importing till needed self.pd = pd self.pd.options.mode.copy_on_write = True @@ -227,7 +112,7 @@ class GPUTrace2Graph: title = 'Model_Engine' x = 'Model_Engine' y = 'Elapsed Time (sec)' - color = 'Substage' + color = 'Category' """ generate kernel mapping table """ # Sort Model_Engine categories by last field after underscore df['Model_Engine'] = self.pd.Categorical( @@ -249,14 +134,13 @@ class GPUTrace2Graph: Generate data table with columns per Model_Engine into result.html """ pivot_df = df.pivot_table(values='Elapsed Time (sec)', - index='Substage', + index='Category', columns='Model_Engine', aggfunc='sum', observed=False).round(2) # Add sum row at bottom pivot_df.loc['total_elapsed_sec'] = pivot_df.sum() pivot_df.fillna('').to_html('temp.html') - print('got') with (open(f'{output_name}.html', 'a', encoding='utf-8') as outfile, open('temp.html', encoding='utf-8') as infile): outfile.write(infile.read()) @@ -264,23 +148,22 @@ class GPUTrace2Graph: print(f'Finished generating: \n' f' {output_name}.html for stack bar chart \n' - f' {output_name}.csv for Kernel-Substage mapping') + f' {output_name}.csv for Kernel-Category mapping') def anno_gpu_kernname(self, df, mapping): - """ add "stage" and "substage" columns """ + """ add "Category" column """ - def anno_gpu_kernname_helper(name, stage): - for kern_name, val in mapping['layer_anno'][stage].items(): + def anno_gpu_kernname_helper(name): + for kern_name, val in mapping.items(): if re.search(kern_name, name): return val - for stage in ['Stage', 'Substage']: - df[stage] = df['Name'].apply(anno_gpu_kernname_helper, stage=stage) + df['Category'] = df['Name'].apply(anno_gpu_kernname_helper) def make_nongpu_row(self, df, nongpu_sec): """ this will append non-gpu time entry at end of df """ nongpu_row = self.pd.DataFrame([df.iloc[-1]]) - nongpu_row['Substage'] = nongpu_row['Name'] = 'CPU(non-GPU)' + nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)' nongpu_row['Instances'] = 1 nongpu_row['Elapsed Time (sec)'] = nongpu_sec return (nongpu_row) @@ -302,7 +185,7 @@ class GPUTrace2Graph: logger.info('generating %s', new_file) return True - def gen_sum_file(self, file): + def gen_sum_file(self, file, nsys_cmd): """ generates sum file from nsys trace with times per kernel and returns the name of the sum file @@ -318,17 +201,21 @@ class GPUTrace2Graph: sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv' if self.should_gen_file(nsys_stats_file, file): cmd = [ - self.nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', + nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', f'{file_dir}/{file_name}' ] cmd_str = ' '.join(cmd) logger.info('+ %s', cmd_str) + # estimate time based on calibrated 240M/min + file_size_mb = os.path.getsize(file) / 1e6 + logger.info( + 'nsys stats for %.2f MB file expected to take %.2f min', + file_size_mb, file_size_mb / 240) try: - subprocess.run(cmd) + subprocess.run(cmd, check=True) except Exception: - logger.error( - "%s failed, specify --nsys_cmd for correct nsys path", - cmd_str) + logger.error("%s failed; Use --nsys_cmd to specify nsys path", + cmd_str) exit(1) logger.info('generating non-overalapped sum %s', sum_file) self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) @@ -336,7 +223,7 @@ class GPUTrace2Graph: logger.info('Finished generating %s', sum_file) return sum_file - def gen_graph(self, in_file, out_dir, title): + def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): """ generates graph and csv file from in_file into out_dir """ # Initialize an empty DataFrame to store combined data combined_df = self.pd.DataFrame() @@ -345,17 +232,16 @@ class GPUTrace2Graph: file_name = os.path.basename(file) if not file_dir: file_dir = '.' - sum_file = self.gen_sum_file(file) + sum_file = self.gen_sum_file(file, nsys_cmd) # read kernel summary file df = self.pd.read_csv(sum_file) # annotate kernel to their categories - assert EngineModelData.engine_model.get(engine) - assert EngineModelData.engine_model[engine].get(model) + assert engine_model.get(engine), f'engine {engine} unknown' + assert engine_model[engine].get(model), f'model {model} unknown' # remove nsys-rep from file_name for shorter x-label file_name = file_name.replace('.nsys-rep', '') df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}' - self.anno_gpu_kernname(df, - EngineModelData.engine_model[engine][model]) + self.anno_gpu_kernname(df, engine_model[engine][model]) # patch in non-gpu time gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1) total_sec = round(float(total_sec), 1) @@ -393,12 +279,12 @@ def main(): "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""), formatter_class=argparse.RawDescriptionHelpFormatter) - # Build help string showing available engine/model combinations - engine_model_help = [] - for engine, models in EngineModelData.engine_model.items(): - model_list = list(models.keys()) - engine_model_help.append(f"{engine}:[{','.join(model_list)}]") - engine_model_str = ' '.join(engine_model_help) + # load supported engine_model + engine_model_supported = load_engine_model() + # Get a string representation of supported engine/model combinations + engine_model_supported_str = ', '.join( + f"{engine}:[{', '.join(models.keys())}]" + for engine, models in engine_model_supported.items()) parser.add_argument( '--in_file', type=parse_tuple, @@ -408,7 +294,7 @@ def main(): 'separated by space. Elapsed_nonprofiled_sec is runtime without ' 'profiling used to calculate non-gpu time. Specify 0 to use ' 'elapsed time from nsys-rep but that might inflate non-gpu time. ' - f'Available engine:[model] are: {engine_model_str} ' + f'Available engine:[model] are: {engine_model_supported_str} ' f'Example: --infile d1.nsys-rep,vllm,llama,100 ' 'd2.nsys-rep,vllm,gpt-oss,102'), required=True) @@ -418,8 +304,9 @@ def main(): help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'), default="nsys") args = parser.parse_args() - gputrace = GPUTrace2Graph(args.nsys_cmd) - gputrace.gen_graph(args.in_file, args.out_dir, args.title) + gputrace = GPUTrace2Graph() + gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, + engine_model_supported) if __name__ == '__main__': diff --git a/tools/profiler/nsys_profile_tools/vllm_engine_model.json b/tools/profiler/nsys_profile_tools/vllm_engine_model.json new file mode 100644 index 0000000000000000000000000000000000000000..264c628dded345957cc454da4534fdd864bc03fa --- /dev/null +++ b/tools/profiler/nsys_profile_tools/vllm_engine_model.json @@ -0,0 +1,63 @@ +{ + "vllm": { + "llama": { + "fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm", + "gemm|nvjet": "gemm", + "moe|sigmoid": "moe", + "CatArrayBatched|prepare_inputs": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "_norm_|Norm": "norm", + "act_and_mul_": "activation", + "Rotary": "rope", + "SoftMax": "softmax", + "flash|fmha": "attn", + "elementwise": "elementwise", + "fp8_quant|cvt_": "quantize", + "reduce_kernel": "reduce", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "ds": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "sbtopk": "topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn", + "triton": "triton_kernel", + "topk": "topk", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "gpt-oss": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert|splitKreduce": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "topk": "topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_|quantize": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c9855fb80b8670b0324577abf0ca051f01d94d8d..e19b25ecf4d17462cbd6f9fb97d8533e65f5d7b5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -790,15 +790,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# # marlin -# def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, -# b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, -# size_n: int, size_k: int) -> torch.Tensor: -# return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, -# size_n, size_k) - - -# # marlin_24 +# marlin_24 # def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # b_meta: torch.Tensor, b_scales: torch.Tensor, # workspace: torch.Tensor, b_q_type: ScalarType, @@ -840,25 +832,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, # is_zp_float: bool = False) -> torch.Tensor: # return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) -# @register_fake("_C::marlin_qqq_gemm") -# def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, -# s_tok: torch.Tensor, s_ch: torch.Tensor, -# s_group: torch.Tensor, workspace: torch.Tensor, -# size_m: torch.SymInt, size_n: torch.SymInt, -# size_k: torch.SymInt) -> torch.Tensor: -# return torch.empty((size_m, size_n), -# dtype=torch.float16, -# device=a.device) - -# @register_fake("_C::marlin_gemm") -# def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, -# b_scales: torch.Tensor, workspace: torch.Tensor, -# size_m: torch.SymInt, size_n: torch.SymInt, -# size_k: torch.SymInt) -> torch.Tensor: -# return torch.empty((size_m, size_n), -# dtype=torch.float16, -# device=a.device) - # @register_fake("_C::awq_dequantize") # def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, # zeros: torch.Tensor, split_k_iters: torch.SymInt, @@ -904,6 +877,30 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, # return torch.empty_like(b_q_weight, # memory_format=torch.contiguous_format) +# @register_fake("_C::cutlass_w4a8_mm") +# def cutlass_w4a8_mm_fake( +# a: torch.Tensor, +# # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b +# b_q: torch.Tensor, +# b_group_scales: torch.Tensor, +# b_group_size: int, +# b_channel_scales: torch.Tensor, +# a_token_scales: torch.Tensor, +# out_type: Optional[torch.dtype] = None, +# maybe_schedule: Optional[str] = None) -> torch.Tensor: +# m = a.size(0) +# n = b_q.size(1) +# out_dtype = out_type if out_type is not None else torch.bfloat16 +# return torch.empty((m, n), device=a.device, dtype=out_dtype) + +# @register_fake("_C::cutlass_pack_scale_fp8") +# def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: +# return torch.empty_like(scales, memory_format=torch.contiguous_format) + +# @register_fake("_C::cutlass_encode_and_reorder_int4b") +# def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: +# return torch.empty_like(b, memory_format=torch.contiguous_format) + # if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -920,7 +917,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, # m = a.size(0) # return torch.empty((m, n), device=a.device, dtype=a.dtype) - if hasattr(torch.ops._C, "ggml_dequantize"): @register_fake("_C::ggml_dequantize") @@ -1291,6 +1287,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, blockscale_offsets) +def get_cutlass_moe_mm_problem_sizes( + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Compute only the per-expert problem sizes needed by the two grouped matrix + multiplications used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token→expert mapping) and computes: + - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's + multiplication for the two grouped MMs + used in the fused MoE operation. + """ + return torch.ops._C.get_cutlass_moe_mm_problem_sizes( + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, + blockscale_offsets) + + def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. @@ -1484,6 +1502,30 @@ def machete_prepack_B( group_scales_type) +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, + b_channel_scales, a_token_scales, + out_type, maybe_schedule) + + +def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_encode_and_reorder_int4b(b) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") @@ -1773,15 +1815,6 @@ def scaled_int8_quant( return output, input_scales, input_azp -# qqq ops -def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - # gguf def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype]) -> torch.Tensor: @@ -1918,6 +1951,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, gating_output) +def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, + num_expert_group: int, topk_group: int, topk: int, + renormalize: bool, routed_scaling_factor: float): + if not current_platform.is_cuda(): + raise NotImplementedError("The fused grouped_topk kernel is only " + "available on CUDA platforms") + return torch.ops._moe_C.grouped_topk(scores, scores_with_bias, + num_expert_group, topk_group, topk, + renormalize, routed_scaling_factor) + + def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_bias: Optional[torch.Tensor], @@ -2045,6 +2089,20 @@ def concat_and_cache_mla( scale) +def cp_fused_concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + cp_local_token_select_indices: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla( + kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping, + kv_cache_dtype, scale) + + def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: @@ -2068,14 +2126,28 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def gather_and_maybe_dequant_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, seq_starts) + + +def cp_gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: @@ -2378,9 +2450,92 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) + +class CPUDNNLGEMMHandler: + + def __init__(self) -> None: + self.handler: Optional[int] = None + self.n = -1 + self.k = -1 + + def __del__(self): + if self.handler is not None: + torch.ops._C.release_dnnl_matmul_handler(self.handler) + + +def create_onednn_scaled_mm( + weight: torch.Tensor, # [K, N] + weight_scales: torch.Tensor, + output_type: torch.dtype, + dynamic_quant: bool, + use_azp: bool, + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( + weight, weight_scales, output_type, dynamic_quant, use_azp, + primitive_cache_size) + return handler + + +def onednn_scaled_int8_quant(input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True): + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + token_num = input.numel() // input.shape[-1] + input = input.view((token_num, input.shape[-1])) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((token_num, 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + + +def onednn_scaled_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + output: torch.Tensor, + input_scale: Optional[torch.Tensor], + input_zp: Optional[torch.Tensor], + input_zp_adj: Optional[torch.Tensor], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, + input_zp_adj, bias, dnnl_handler.handler) + + return output + direct_register_custom_op( op_name="awq_gemm", op_func=awq_gemm, mutates_args=[], fake_impl=awq_gemm_fake, -) \ No newline at end of file +) diff --git a/vllm/assets/image.py b/vllm/assets/image.py index c977242a3d48427255971cf9f3836f445ca5a4ad..c8f8d43a983553ca6ea02dae9e90c76e63e3f093 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -11,7 +11,7 @@ from .base import get_vllm_public_assets VLM_IMAGES_DIR = "vision_model_images" -ImageAssetName = Literal["stop_sign", "cherry_blossom"] +ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato"] @dataclass(frozen=True) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a53266b447b17fefc984e4b58b0c6f0..dcb2aa68fbee9b4a3e3b2b97489aae93e1c782ec 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ __all__ = [ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d21f07756871aa29fc9f6ee604e690bf05fe8541..0b9c625533cb71b2c0a17d1da8ccb89ec68ab66d 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - TODO(luka) merge parameters into QuantDescriptor - :param dtype: quantized dtype - :param static: static or dynamic quantization - :param group_shape: quant group shape. + :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False @@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fac3c318a87a028dea1adeeac84ed303ed86c2e5..caa02530d2fd6c12946644ce38e69982f4dd20e3 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -800,23 +800,33 @@ class DifferentialFlashAttentionImpl(AttentionImpl): attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + layer: Attention layer instance. + q: Query tensor with shape = [num_tokens, num_heads, head_size] + k: Key tensor with shape = [num_tokens, num_kv_heads, head_size] + v: Value tensor with shape = [num_tokens, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. + output: Output tensor with shape [num_tokens, num_heads, head_size] + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. NOTE: It in-place updates the output tensor. NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for DifferentialFlashAttentionImpl") + if self.lambda_full is None: self.lambda_init = self.differential_flash_attention_config[ "lambda_init"] diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index a37f6fcff5c3ea6da6bb7a0b4d56abdac466033e..8190123f4dc7e6e033d7403d88d448d00dec9e20 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -376,6 +376,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): attn_metadata: DualChunkFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. Args: @@ -391,7 +392,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): """ assert output is None, "Output tensor not supported for DualChunk" - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 29125c0ec0530bec8005ba09b3a1b7e5341245b5..85b256c188370389e73993febf815cdb67a509cd 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -603,6 +603,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -611,7 +612,8 @@ class FlashAttentionImpl(AttentionImpl): key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. @@ -622,7 +624,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -925,7 +927,7 @@ class FlashAttentionImpl(AttentionImpl): def _get_query_key_seq_metadata( - attn_metadata, + attn_metadata: FlashAttentionMetadata, is_prompt: bool, attn_type: str, ) -> tuple: diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index d21e677587b1b12466455fdcd9fc9817c401123e..cf2de9cc7dd1b3ccc3b3cb14b9b4f6b59492553a 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): self.context_chunk_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel cannot + # handle `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -1090,6 +1090,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None @@ -1111,12 +1112,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.context_chunk_starts[i], ) @@ -1173,6 +1176,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata @@ -1208,7 +1212,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1245,12 +1249,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLAImplBase") @@ -1298,7 +1303,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: decode_q_nope, decode_q_pe = decode_q.split( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ece6f5d61f34e465e03b2b857a6569680612d9c0..98f6afb5d0e0b2a5635850df911bf89acb5f63c5 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -23,7 +23,7 @@ from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform from vllm.utils import SUPPORT_TC, gpuname @@ -549,11 +549,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): if self.use_triton_flash_attn: - return dtype == current_platform.fp8_dtype( - ) and static and group_shape == GroupShape.PER_TENSOR + return quant_key == kFp8StaticTensorSym # Only supported in the Triton backend return False @@ -568,6 +566,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -605,17 +604,18 @@ class ROCmFlashAttentionImpl(AttentionImpl): use prefill sequence attributes Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. Returns: shape = [num_tokens, num_heads * head_size] """ @@ -626,6 +626,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): "fused output quantization only supported for Triton" " implementation in ROCMFlashAttentionImpl for now") + if output_block_scale is not None: + raise NotImplementedError( + "fused nvfp4 output quantization is not supported" + " for ROCMFlashAttentionImpl") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 772b628147e309cd7738b445f8ad16621bd16d5e..d973027e4ef7f99f01e2a1a12df6ff6fc7c76501 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -585,7 +585,7 @@ def get_num_prefill_decode_query_kv_tokens( Raises: AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. + is `None` when required for the calculations. """ num_prefill_query_tokens = 0 num_decode_query_tokens = 0 diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6e8a7be4e3da5f54fabd14bf40a4d9319d15d5bd..56d69b893534588a2851eb62c7ed57ed5434f088 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -439,6 +439,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -477,21 +478,22 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): max_encoder_seq_len) Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersImpl") @@ -654,7 +656,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): for API spec. Args: - output: shape = [num_prefill_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3214e2bec4dcd0595770241c154da847c095c68b..5ef42b0556eb8adf1782b477c7a16b22ec60a470 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -54,7 +55,7 @@ def check_xformers_availability(): return USE_XFORMERS_OPS -class Attention(nn.Module): +class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. The input tensors @@ -128,11 +129,17 @@ class Attention(nn.Module): self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep the float32 versions of k/v_scale for attention - # backends that don't support tensors (Flashinfer) + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. + self._o_scale_float: Optional[float] = None + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size @@ -183,8 +190,7 @@ class Attention(nn.Module): # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.is_cuda_alike( - ) and not current_platform.is_cpu() + self.use_direct_call = not current_platform.opaque_attention_op() self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config @@ -291,6 +297,7 @@ class Attention(nn.Module): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once @@ -488,6 +495,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -503,7 +511,8 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, - output_scale=output_scale) + output_scale=output_scale, + output_block_scale=output_block_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -515,6 +524,7 @@ def unified_attention_with_output_fake( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: return @@ -522,7 +532,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["output"], + mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 892077ba91e07a40f0f3f360ab0cfa20a23334b9..087c5004bde06815eee2767dd4d2a01b5116ff50 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -6,12 +6,13 @@ from typing import List, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, QuantizationConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend, subclass_attention_metadata_builder) + subclass_attention_backend) from ..layer import Attention @@ -24,21 +25,23 @@ def create_chunked_local_attention_backend( ) -> type[AttentionBackend]: prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" - def build_preprocess_fn(cm: CommonAttentionMetadata): - return make_local_attention_virtual_batches(attention_chunk_size, cm, - block_size) + underlying_builder = underlying_attn_backend.get_builder_cls() + + class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + common_attn_metadata = make_local_attention_virtual_batches( + attention_chunk_size, common_attn_metadata, block_size) + return super().build(common_prefix_len, common_attn_metadata, + fast_build) - # Dynamically create a new attention backend that wraps the - # underlying attention backend but applies - # `make_local_attention_virtual_batches` before calling `build(...)` - builder_cls = subclass_attention_metadata_builder( - name_prefix=prefix, - builder_cls=underlying_attn_backend.get_builder_cls(), - build_preprocess_fn=build_preprocess_fn) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=builder_cls) + builder_cls=ChunkedLocalAttentionBuilder) return attn_backend diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cea05df5b96d2ae2014c5108dcb440a03f313f94 --- /dev/null +++ b/vllm/attention/layers/encoder_only_attention.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) + + +@functools.lru_cache +def create_encoder_only_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "EncoderOnlyAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata = copy(common_attn_metadata) + new_common_attn_metadata.causal = False + return super().build(common_prefix_len, new_common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=EncoderOnlyAttentionBuilder) + + return attn_backend + + +class EncoderOnlyAttention(Attention): + """ + Encoder attention is a special case that doesn't need a KV Cache. + """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_encoder_only_attention_backend( + underlying_attn_backend) + else: + # in v0 encoder only attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_ONLY, \ + "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index dcf99f0da50ae44003ef19eda2922d26f497c308..c7bd4007f61d7c44048dbe5a58754d4b35b4f72e 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -75,8 +75,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, - k_scale = None, - kv_cache_dtype = "auto", + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -91,6 +91,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q: (batch_size), torch.float32. Descaling factors for Q. + descale_k: (batch_size), torch.float32. Descaling factors for K. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -99,22 +101,22 @@ def flash_mla_with_kvcache( if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) if current_platform.is_rocm(): - if kv_cache_dtype == "fp8": - out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( - q, - k_cache, - None, - head_dim_v, - cache_seqlens, - block_table, - softmax_scale, - causal, - tile_scheduler_metadata, - num_splits, - k_scale, - "fp8_e4m3", - ) - return out, softmax_lse + # if kv_cache_dtype == "fp8": + # out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( + # q, + # k_cache, + # None, + # head_dim_v, + # cache_seqlens, + # block_table, + # softmax_scale, + # causal, + # tile_scheduler_metadata, + # num_splits, + # k_scale, + # "fp8_e4m3", + # ) + # return out, softmax_lse out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, @@ -126,12 +128,13 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + # descale_q, + # descale_k, ) else: out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, - None, head_dim_v, cache_seqlens, block_table, @@ -139,6 +142,8 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse diff --git a/vllm/beam_search.py b/vllm/beam_search.py index f3bc4218323d80f45fca8cb820065edd827648bc..5a2e79e1b5c74d627c8fac09ee490d36ad14bdd2 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -18,7 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ - # The tokens includes the prompt. + # The tokens include the prompt. tokens: list[int] logprobs: list[dict[int, Logprob]] lora_request: Optional[LoRARequest] = None diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 72d7ce49b8e143d368821c79380e12f976256f8d..93519b5ba1523bf9e9a4b6f88ab5b85cdd6a2909 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,17 +11,21 @@ generation. Supported dataset types include: - HuggingFace - VisionArena """ +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image @@ -69,13 +73,14 @@ class SampleRequest: Represents a single inference request for benchmarking. """ - prompt: Union[str, Any] + prompt: Union[str, list[str]] prompt_len: int expected_output_len: int multi_modal_data: Optional[ Union[MultiModalDataDict, dict, list[dict]] ] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -112,7 +117,9 @@ class BenchmarkDataset(ABC): def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + mm_content: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -120,7 +127,15 @@ class BenchmarkDataset(ABC): """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -183,7 +198,8 @@ class BenchmarkDataset(ABC): @abstractmethod def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + num_requests: int, + request_id_prefix: str = "") -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -194,6 +210,8 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. + Returns: list[SampleRequest]: A list of sample requests generated from the @@ -201,8 +219,12 @@ class BenchmarkDataset(ABC): """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. @@ -211,11 +233,17 @@ class BenchmarkDataset(ABC): requests (List[SampleRequest]): The current list of sampled requests. num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. + """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -266,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]: """ Process a single image input and return a multimedia content dictionary. - Supports three input types: + Supports the following input types: 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key containing raw image data. - Loads the bytes as a PIL.Image.Image. @@ -306,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]: " or str or dictionary with raw image bytes.") +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and 'bytes' in video: + video_bytes = video['bytes'] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": { + "url": f"data:video/mp4;base64,{video_base64}" + }, + } + + if isinstance(video, str): + video_url = (video if video.startswith( + ("http://", "file://")) else f"file://{video}") + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ # Default values copied from benchmark_serving.py for the random dataset. DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. + self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), + ) + ) + # only used for embeddings benchmark. + if batchsize > 1: + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + batch_requests.append( + SampleRequest( + prompt=[req.prompt for req in batch], + prompt_len=sum(req.prompt_len for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + requests = batch_requests + return requests + + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. + """ + return ( + self._rng.integers( + 0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. + output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + "Invalid input sampling interval: " + f"low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) - # Add logging for debugging logger.info( "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) + input_low, + input_high, + output_low, + output_high, + ) - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, + input_lens = self._rng.integers(input_low, input_high + 1, + size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, + size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) + return input_lens, output_lens, offsets - requests = [] + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decode again. + """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) + % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_len) + + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + + return prompt, total_input_len + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. + """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. + """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. + """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. + if any(self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items()): + raise NotImplementedError("Video sampling not implemented; " + "set its probability to 0.") + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) - requests.append( - SampleRequest( + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) - return requests - + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + return mm_requests # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -432,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -455,9 +983,10 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None): continue - # TODO: Also support ShareGPT4Video. if image_path := entry.get("image"): mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) else: mm_content = None if enable_multimodal_chat: @@ -470,8 +999,10 @@ class ShareGPTDataset(BenchmarkDataset): expected_output_len=new_output_len, lora_request=lora_request, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -488,8 +1019,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default="random", choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", - "prefix_repetition" + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition" ], help="Name of the dataset to benchmark on.", ) @@ -589,6 +1120,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "context length sampled from [input_len * (1 - range_ratio), " "input_len * (1 + range_ratio)]."), ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. " + "Only used for embeddings benchmark."), + ) + + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. " + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." + ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." + "OBS bis.: Only image sampling is supported for now." + ), + ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", @@ -647,6 +1275,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -660,6 +1289,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -671,6 +1301,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -716,10 +1347,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: "openai-chat", "openai-audio", ]: - # multi-modal benchmark is only available on OpenAI Chat backend. + # multi-modal benchmark is only available on OpenAI Chat + # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend.") + "'openai-audio' endpoint-type.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -730,31 +1362,54 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( @@ -766,10 +1421,18 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: suffix_len=args.prefix_repetition_suffix_len, num_prefixes=args.prefix_repetition_num_prefixes, output_len=args.prefix_repetition_output_len, + request_id_prefix=args.request_id_prefix, ), } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.endpoint_type not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err @@ -839,10 +1502,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -864,8 +1528,10 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -909,6 +1575,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -934,6 +1601,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices(self.data, k=num_input_lines - num_prefix_lines) @@ -949,7 +1617,9 @@ class SonnetDataset(BenchmarkDataset): if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), )) + ind += 1 return samples @@ -1000,6 +1670,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -1020,6 +1691,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), )) return samples @@ -1075,11 +1747,13 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter( lambda x: len(x["conversations"]) >= 2) sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in filtered_data: @@ -1111,8 +1785,11 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1141,12 +1818,13 @@ class VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -1168,8 +1846,10 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1198,15 +1878,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) # apply template prompt = tokenizer.apply_chat_template( @@ -1224,8 +1907,10 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1255,13 +1940,14 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -1282,8 +1968,10 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1305,8 +1993,10 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs) -> list: sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in self.data: @@ -1331,8 +2021,12 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), + )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1403,13 +2097,14 @@ class NextEditPredictionDataset(HuggingFaceDataset): } def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", **kwargs): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1417,10 +2112,11 @@ class NextEditPredictionDataset(HuggingFaceDataset): prompt_len=len(tokenizer(sample["prompt"]).input_ids), expected_output_len=len( tokenizer(sample["expected_output"]).input_ids), + request_id=request_id_prefix + str(i), )) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1470,6 +2166,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len @@ -1477,6 +2174,7 @@ class ASRDataset(HuggingFaceDataset): prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] + ind = 0 skipped = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1496,7 +2194,9 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1504,7 +2204,8 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1541,11 +2242,13 @@ class MLPerfDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. dynamic_output = output_len is None sampled_requests: list[SampleRequest] = [] + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1580,10 +2283,13 @@ class MLPerfDataset(HuggingFaceDataset): prompt=prompt_formatted, prompt_len=prompt_len, expected_output_len=expected_output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1616,6 +2322,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): suffix_len: int = DEFAULT_SUFFIX_LEN, num_prefixes: int = DEFAULT_NUM_PREFIXES, output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: vocab_size = tokenizer.vocab_size diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 47bc288774504d91abe8c9649694376d802633e6..6bb2a497119e9f9f713f576a7a38282d380143b7 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -9,7 +9,7 @@ import sys import time import traceback from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Union import aiohttp from tqdm.asyncio import tqdm @@ -28,9 +28,10 @@ class RequestFuncInput: model_name: Optional[str] = None logprobs: Optional[int] = None extra_body: Optional[dict] = None - multi_modal_content: Optional[dict | list[dict]] = None + multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -68,8 +69,8 @@ async def async_request_openai_completions( ), "OpenAI Completions API URL must end with 'completions' or 'profile'." payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -87,6 +88,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -132,7 +135,7 @@ async def async_request_openai_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" @@ -210,6 +213,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -249,7 +254,7 @@ async def async_request_openai_chat_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): @@ -311,6 +316,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): @@ -387,12 +394,61 @@ async def async_request_openai_audio( return output +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +): + api_url = request_func_input.api_url + assert api_url.endswith( + "embeddings" + ), "OpenAI Embeddings API URL must end with 'embeddings'." + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + payload = { + "model": request_func_input.model, + "input": request_func_input.prompt, + } + + output = RequestFuncOutput() + st = time.perf_counter() + try: + async with session.post( + url=api_url, + headers=headers, + json=payload + ) as response: + if response.status == 200: + output.latency = time.perf_counter() - st + data = await response.json() + output.success = True + output.generated_text = "" + output.prompt_len = data.get( + "usage", {}).get( + "prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + # TODO: Add more request functions for different API protocols. ASYNC_REQUEST_FUNCS = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, } OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py index 5f95fdcc758295b84ea16a1bfa14df0d8ce39899..0c27687dcf16db0d1b84b5b1c1728ef7a793199d 100644 --- a/vllm/benchmarks/lib/utils.py +++ b/vllm/benchmarks/lib/utils.py @@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder): def clear_inf(self, o: Any): if isinstance(o, dict): - return {k: self.clear_inf(v) for k, v in o.items()} + return { + str(k) + if not isinstance(k, (str, int, float, bool, type(None))) + else k: self.clear_inf(v) + for k, v in o.items() + } elif isinstance(o, list): return [self.clear_inf(v) for v in o] elif isinstance(o, float) and math.isinf(o): diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 7bf04c753241153151fe9f3f21aac184d9f6e88b..abb838316cd31cdce50f07f19c49342edf8f03cf 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -4,7 +4,7 @@ r"""Benchmark online serving throughput. On the server side, run one of the following commands to launch the vLLM OpenAI API server: - vllm serve <your_model> <engine arguments> + vllm serve <your_model> <engine arguments> On the client side, run: vllm bench serve \ @@ -26,6 +26,7 @@ import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime +from enum import Enum from typing import Any, Literal, Optional import aiohttp @@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +class TaskType(Enum): + GENERATION = "generation" + EMBEDDING = "embedding" + + @dataclass class BenchmarkMetrics: completed: int @@ -75,6 +81,16 @@ class BenchmarkMetrics: std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] +@dataclass +class EmbedBenchmarkMetrics: + completed: int + total_input: int + request_throughput: float + total_token_throughput :float + mean_e2el_ms: float + std_e2el_ms: float + median_e2el_ms: float + percentiles_e2el_ms: float def _get_current_request_rate( ramp_up_strategy: Optional[Literal["linear", "exponential"]], @@ -146,11 +162,11 @@ async def get_request( delay_ts = [] for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -160,7 +176,7 @@ async def get_request( # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) - + # Calculate the cumulative delay time from the first sent out requests. for i in range(1, len(delay_ts)): delay_ts[i] += delay_ts[i - 1] @@ -170,11 +186,11 @@ async def get_request( # logic would re-scale delay time to ensure the final delay_ts # align with target_total_delay_s. # - # NOTE: If we simply accumulate the random delta values - # from the gamma distribution, their sum would have 1-2% gap + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data - # from different random seeds. + # close the gap for stablizing the throughput data + # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] delay_ts = [delay * normalize_factor for delay in delay_ts] @@ -189,6 +205,51 @@ async def get_request( yield request, request_rates[request_index] +def calculate_metrics_for_embeddings( + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. + + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles + ], + ) + return metrics + + def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -334,8 +395,16 @@ async def benchmark( ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): + task_type = ( + TaskType.EMBEDDING + if api_url.endswith("/v1/embeddings") + else TaskType.GENERATION + ) if endpoint_type in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + if task_type == TaskType.EMBEDDING: + request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] + else: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] else: raise ValueError(f"Unknown endpoint_type: {endpoint_type}") @@ -421,8 +490,8 @@ async def benchmark( if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") @@ -449,7 +518,7 @@ async def benchmark( session=session, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, + return await request_func(request_func_input=request_func_input, session=session, pbar=pbar) @@ -478,11 +547,12 @@ async def benchmark( "timestamp": timestamp }) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -498,7 +568,8 @@ async def benchmark( logprobs=logprobs, multi_modal_content=mm_content, ignore_eos=ignore_eos, - extra_body=extra_body) + extra_body=extra_body, + request_id=request_id,) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -511,14 +582,22 @@ async def benchmark( benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) @@ -527,39 +606,55 @@ async def benchmark( max_concurrency)) if request_rate != float('inf'): print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate )) + request_rate)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) + if isinstance(metrics, BenchmarkMetrics): + print("{:<40} {:<10}".format( + "Total generated tokens:", metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) + if isinstance(metrics, BenchmarkMetrics): + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } + if isinstance(metrics, BenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } if rps_change_events: result["rps_change_events"] = rps_change_events @@ -596,10 +691,11 @@ async def benchmark( value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric( + "tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") print("=" * 50) @@ -730,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -741,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -865,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) + sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( @@ -958,6 +1062,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) + async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) @@ -1036,32 +1141,32 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: gc.freeze() benchmark_result = await benchmark( - endpoint_type=args.endpoint_type, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ready_check_timeout_sec=args.ready_check_timeout_sec, - ) + endpoint_type=args.endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) # Save config and results to json result_json: dict[str, Any] = {} @@ -1088,7 +1193,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Traffic result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + < float("inf") else "inf") result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1122,7 +1227,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.max_concurrency is not None else "") label = label or endpoint_type if args.ramp_up_strategy is not None: - file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: @@ -1139,4 +1244,4 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) - return result_json \ No newline at end of file + return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 0c19fa6dcfdd230d752b5839d128896e8b427636..f022a55e625f58055b741d7ef58d11d9d6fe4dac 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -434,6 +434,14 @@ def validate_args(args): if args.backend == "mii" and args.tokenizer != args.model: raise ValueError( "Tokenizer must be the same as the model for MII backend.") + + # --data-parallel is not supported currently. + # https://github.com/vllm-project/vllm/issues/16222 + if args.data_parallel_size > 1: + raise ValueError( + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" + ) def add_cli_args(parser: argparse.ArgumentParser): diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index ce4e50a2b02d11ecb3658925eb1c45e09e5feac1..03c0a0e811c8812d46a62d07cd77bbd47657e6e7 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -1,54 +1,155 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, register_replacement) +from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + +SILU_MUL_OP = torch.ops._C.silu_and_mul.default -def silu_mul_pattern_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - return at2[1] +# FUSED_OPS: dict[QuantKey, OpOverload] = { +# kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 +# } +# silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr( +# torch.ops._C, "silu_and_mul_nvfp4_quant")) +# if silu_and_mul_nvfp4_quant_supported: +# FUSED_OPS[ +# kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -def silu_mul_replacement_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, - result=result, - input=input, - scale=scale) - return at[1] +class ActivationQuantPattern(ABC): + """ + The base class for Activation+Quant fusions. + Should not be used directly. + """ + def __init__( + self, + quant_key: QuantKey, + ): + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype -def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + assert self.quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {self.quant_key}" + self.QUANT_OP = QUANT_OPS[self.quant_key] + assert self.quant_key in FUSED_OPS, \ + f"unsupported fusion scheme {self.quant_key}" + self.FUSED_OP = FUSED_OPS[self.quant_key] -def empty_fp8(*args, **kwargs): - fp8 = current_platform.fp8_dtype() - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + def empty_quant(self, *args, **kwargs): + kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + @abstractmethod + def register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError -def empty_fp32(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + +class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Fp8StaticQuant Pattern + """ + + def __init__(self, symmetric: bool = True): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(quant_key) + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + scale=scale) + return at[1] + + inputs = [ + self.empty_quant(5, 4), # result + empty_bf16(5, 4), # result_silu_mul + empty_bf16(5, 4), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) + + +class SiluMulNvfp4QuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Nvfp4Quant Pattern + """ + + def __init__(self): + super().__init__(kNvfp4Quant) + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + output=result, + input=at1[1], + output_scale=output_scale, + input_scale=scale) + return at2[1], at2[2] + + def replacement(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + result_block_scale=output_scale, + input=input, + input_global_scale=scale) + return at[1], at[2] + + inputs = [ + self.empty_quant(5, 32), # result + empty_i32(128, 4), # output_scale + empty_bf16(5, 64), # result_silu_mul + empty_bf16(5, 64), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) class ActivationQuantFusionPass(VllmInductorPass): @@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="activation_quant_fusion_pass") - inputs = [ - empty_fp8(5, 4), # Quant output - empty_bf16(5, 4), # Silu_and_mul output - empty_bf16(5, 4), # Input - empty_fp32(1, 1) # Scale - ] - register_replacement(silu_mul_pattern_static, - silu_mul_replacement_static, inputs, fwd_only, - self.patterns) + pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern() + pattern_silu_mul_fp8.register(self.patterns) + + if silu_and_mul_nvfp4_quant_supported: + pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() + pattern_silu_mul_nvfp4.register(self.patterns) def __call__(self, graph: torch.fx.Graph): self.begin() @@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass): self.dump_graph(graph, "after_act_quant_fusion") self.end_and_log() + + def uuid(self): + return VllmInductorPass.hash_source(self, ActivationQuantPattern, + SiluMulFp8StaticQuantPattern, + SiluMulNvfp4QuantPattern) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 059e7a3b29761c7b730c757eb68136fd043f514e..3361b65a9b8852c5d18aa0656a913fc0bb522fe8 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -271,7 +271,7 @@ def split_graph(graph: fx.GraphModule, outputs.append( SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) - # sort by intetger graph_id, rather than string name + # sort by integer graph_id, rather than string name outputs.sort(key=lambda x: x.graph_id) return split_gm, outputs @@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): def __init__(self, module: torch.fx.GraphModule, compile_submod_names: list[str], vllm_config: VllmConfig, - graph_pool, vllm_backend: "VllmBackend"): + vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.vllm_config = vllm_config self.vllm_backend = vllm_backend # When True, it annoyingly dumps the torch.fx.Graph on errors. @@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): runnable=piecewise_backend, vllm_config=self.vllm_config, runtime_mode=CUDAGraphMode.PIECEWISE, - graph_pool=self.graph_pool, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, @@ -405,7 +403,6 @@ class VllmBackend: vllm_config: VllmConfig compilation_config: CompilationConfig - graph_pool: Any _called: bool = False # the graph we compiled graph: fx.GraphModule @@ -427,19 +424,12 @@ class VllmBackend: # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, - # e.g. launguage_model, vision_model, etc. + # e.g. language_model, vision_model, etc. # when multiple parts are initialized as independent # models, we need to use the model_tag to distinguish # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global_graph_pool = current_platform.get_global_graph_pool() - - # TODO: in the future, if we want to use multiple - # streams, it might not be safe to share a global pool. - # only investigate this when we use multiple streams - self.graph_pool = global_graph_pool - # Passes to run on the graph post-grad. self.post_grad_pass_manager = PostGradPassManager() @@ -484,7 +474,7 @@ class VllmBackend: factors = [] # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affects the computation graph. + # VLLM_PP_LAYER_PARTITION will affect the computation graph. env_hash = envs.compute_hash() factors.append(env_hash) @@ -586,7 +576,7 @@ class VllmBackend: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, self.graph_pool, + self.vllm_config, self).run(*example_inputs) graph_path = os.path.join(local_cache_dir, "computation_graph.py") diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 1c3f52c533b131934dbc520cb4103251c2a6d8c1..161d066ce9fb8e19b7b4e2c46376c3b5adf671d0 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol): """ def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs): + runtime_mode: CUDAGraphMode, **kwargs): """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol): graph runtime. See CUDAGraphMode in vllm/config.py. Note that only the subset enum `NONE`, `PIECEWISE` and `FULL` are used as concrete runtime mode for cudagraph dispatching. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. Keyword Args: kwargs: Additional keyword arguments for platform-specific configurations. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 6ae50245ed3a896876e1b1f92d54dce385d0f596..7a99aaff707dcff30625ea33687ce54a887b9fb3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -18,6 +19,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass FP8_DTYPE = current_platform.fp8_dtype() @@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern): class AsyncTPPass(VllmInductorPass): + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -401,6 +404,18 @@ if flashinfer_comm is not None: 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB } + + try: + _FI_MAX_SIZES.update({ + int(k): int(float(v) * MiB) + for k, v in + envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() + }) + except Exception as e: + raise ValueError( + "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + + str(e)) from e + # opt for a more conservative default value # when world size is not in _FI_MAX_SIZES _DEFAULT_FI_MAX_SIZE = MiB // 2 @@ -465,7 +480,8 @@ if flashinfer_comm is not None: quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + layout_code=flashinfer_comm.QuantizationSFLayout. + SWIZZLED_128x4, scale_factor=scale_factor, ) else: @@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass): # in fallback path, when we don't use flashinfer fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + self.register_patterns() + + @enable_fake_mode + def register_patterns(self): for epsilon in [1e-5, 1e-6]: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 65a38197ad4e28c2daecc7a34fc1b99e6c829fbb..e233f959c0a4a2ad26e09f2b8578185044a29db2 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -67,11 +67,9 @@ class CUDAGraphWrapper: runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - graph_pool: Any = None, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config - self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -81,8 +79,10 @@ class CUDAGraphWrapper: # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't # need to initialize a CUDAGraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE - if self.graph_pool is None: - self.graph_pool = current_platform.get_global_graph_pool() + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index c0acfca918d552e1a4140f79966d3b197660f705..cc436bf1ee568a202055203b63648897c9f52b78 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -54,6 +54,14 @@ def _should_ignore_torch_compile(cls) -> bool: return getattr(cls, IGNORE_COMPILE_KEY, False) +@overload +def support_torch_compile( + *, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, +) -> Callable[[_T], _T]: + ... + + @overload def support_torch_compile( *, @@ -71,6 +79,7 @@ def support_torch_compile( cls: Optional[_T] = None, *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -120,6 +129,11 @@ def support_torch_compile( NOTE: if an argument is `None`, it should always be passed as `None` during the lifetime of the model, otherwise, it cannot be captured as a single computation graph. + + `enable_if` is a function that takes a `VllmConfig` object as input and + returns a boolean value indicating whether to compile the model or not. + This is useful if you want to compile the model only when certain + conditions are met. """ def cls_decorator_helper(cls: _T) -> _T: @@ -151,7 +165,8 @@ def support_torch_compile( if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, + enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -164,6 +179,7 @@ def support_torch_compile( def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -184,13 +200,14 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config + enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) + self.__class__) or not enable_compile if self.do_not_compile: return @@ -273,8 +290,24 @@ def _support_torch_compile( code.co_filename) return inline_call(parent, func, args, kwargs) + # Disable the C++ compilation of symbolic shape guards. C++-fication + # of symbolic shape guards can improve guard overhead. But, since + # vllm skip guards anyways, setting this flag to False can improve + # compile time. + dynamo_config_patches = {} + try: + _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards + dynamo_config_patches[ + "enable_cpp_symbolic_shape_guards"] = False + except AttributeError: + # Note: this config is not available in torch 2.6, we can skip + # if the config doesn't exist + logger.debug( + "enable_cpp_symbolic_shape_guards config not available") + with patch.object(InliningInstructionTranslator, 'inline_call', - patched_inline_call): + patched_inline_call), torch._dynamo.config.patch( + **dynamo_config_patches): output = self.compiled_callable(*args, **kwargs) return output diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 48f32c81d7b86c7d28832ae7f22e2674f3dd3927..d99e3b7cfcfeaa2e9d07df3b9e8518815d972d14 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -9,6 +9,7 @@ import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger +from vllm.platforms import current_platform from .fx_utils import is_func from .vllm_inductor_pass import VllmInductorPass @@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass): """ def __call__(self, graph: torch.fx.Graph): + # XPU does not support auto-functionalization yet. + # Will enable this when switch to vllm-xpu-kernels. + if current_platform.is_xpu(): + logger.debug("XPU platform does not support fix functionalization" + "pass currently.") + return + self.begin() self.dump_graph(graph, "before_fix_functionalization") @@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass): # node, # mutated_args, # args=('result', 'input', 'scale')) + # elif hasattr( + # torch.ops._C, "silu_and_mul_nvfp4_quant" + # ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: + # mutated_args = {1: 'result', 2: 'result_block_scale'} + # self.defunctionalize(graph, + # node, + # mutated_args, + # args=('result', 'result_block_scale', + # 'input', 'input_global_scale')) else: continue # skip the count diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 41cefb3debc4a853c48a8cedd5fa8dc07c9d8664..273b87a696054876474236cc6203023c58547966 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,15 +12,18 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe +from .inductor_pass import enable_fake_mode from .multi_output_match import MultiOutputMatch from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def empty_bf16(*args, **kwargs): @@ -31,41 +34,12 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - - -class QuantKey(NamedTuple): - """ - Named tuple for identifying the type of quantization. - dtype: quantized data type - static: static quantization if True, dynamic if False - group_shape: quantization group shape - symmetric: symmetric if True, asymmetric if False - - TODO(luka) use QuantDescriptor once standardized: - https://github.com/vllm-project/vllm/issues/8913 +def empty_i32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") - """ - dtype: torch.dtype - static: bool - group_shape: GroupShape - symmetric: bool = True - def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," - f"{'a' if not self.symmetric else ''}symmetric)") - - -# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) -# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) -# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { # kFp8StaticTensorSym: @@ -75,6 +49,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { # kFp8DynamicTokenSym: # torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[ + kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -187,11 +164,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -244,11 +219,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, @@ -337,10 +310,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -435,10 +408,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -556,6 +529,7 @@ class FusionPass(VllmInductorPass): cls._instance.pass_config = config.compilation_config.pass_config return cls._instance + @enable_fake_mode def __init__(self, config: VllmConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index a40a8caf34a883a45822bc82f961adef268ac7a1..3095f17110fdecc3c3fd23b96e2b763715058d86 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -1,45 +1,52 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._subclasses.fake_tensor import (FakeTensorMode, - unset_fake_temporarily) from vllm.attention import Attention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from vllm.utils import round_up -from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionStaticQuantPattern: +class AttentionQuantPattern(ABC): + """ + The base class for Attn+Quant fusions. + Should not be used directly. + """ def __init__( self, - layer_name: str, - num_heads: int, - head_size: int, - quant_dtype: torch.dtype, - symmetric=True, + layer: Attention, + quant_key: QuantKey, ): - self.layer_name = layer_name - self.num_heads = num_heads - self.head_size = head_size - self.quant_dtype = quant_dtype - self.quant_key = QuantKey(dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + self.layer = layer + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.head_size = layer.head_size + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -48,31 +55,64 @@ class AttentionStaticQuantPattern: kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) - def register_if_supported(self, pm_pass: PatternMatcherPass, - layer: Attention): - if layer.impl.fused_output_quant_supported(self.quant_dtype, - self.quant_key.static, - self.quant_key.group_shape): + @staticmethod + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + + def register_if_supported(self, pm_pass: PatternMatcherPass): + if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) + @abstractmethod + def _register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError + + +class AttentionFp8StaticQuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Fp8StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Fp8StaticQuant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__( + self, + layer: Attention, + symmetric: bool = True, + ): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(layer, quant_key) + def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_attn, - [-1, self.num_heads, self.head_size]) - at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) - + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -82,47 +122,116 @@ class AttentionStaticQuantPattern: def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_quant, - [-1, self.num_heads, self.head_size]) - + # attn output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size], + 0.0, + dtype=self.quant_dtype, + device=q.device) at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=scale) - + output_scale=scale, + output_block_scale=None) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) - # Need custom fake mode, otherwise tracing happens with real tensors. - # That would not work for the unified_attention custom op. - with unset_fake_temporarily(), FakeTensorMode(): - inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads * self.head_size), # attn_output - self.empty_quant(5, self.num_heads * - self.head_size), # quant_output - empty_fp32(1, 1) # scale - ] + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # attn_output + self.empty_quant(5, + self.num_heads * self.head_size), # quant_output + empty_fp32(1, 1) # scale + ] - def wrap_trace_fn(process_fx, trace_fn): + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) - def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) - return wrapped +class AttentionNvfp4QuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Nvfp4Quant. - def fx_view_to_reshape(gm: torch.fx.GraphModule): - from torch._inductor.fx_passes.post_grad import view_to_reshape - view_to_reshape(gm) - return gm + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Nvfp4Quant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__(self, layer: Attention): + super().__init__(layer, kNvfp4Quant) + + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) + at2 = auto_functionalized(self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view - pm.register_replacement( - pattern, replacement, inputs, - wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + # attention output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size // 2], + 0.0, + dtype=self.quant_dtype, + device=q.device) + # attention output block scale + output_scale_view = torch.ops.aten.view.dtype( + output_scale, FP8_DTYPE) + at2 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view) + output = RESHAPE_OP(at2[1], + [-1, self.num_heads * self.head_size // 2]) + return output, at2[2] + + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // + 2), # output_quant + empty_i32(128, round_up(self.num_heads * self.head_size // 16, + 4)), # output_scale + empty_fp32(1, 1), # input_scale + ] + + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) class AttnFusionPass(VllmInductorPass): @@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass): support are attention kernels, which need to support fusing output quant. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) - self.static_fwd_ctx = config.compilation_config.static_forward_context self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for key, layer in self.static_fwd_ctx.items(): - pattern = AttentionStaticQuantPattern(key, layer.num_heads, - layer.head_size, - current_platform.fp8_dtype()) - pattern.register_if_supported(self.patterns, layer) - if len(self.static_fwd_ctx) == 0: + attn_layers = get_layers_from_vllm_config(config, Attention) + for layer_name, layer in attn_layers.items(): + pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8.register_if_supported(self.patterns) + + pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4.register_if_supported(self.patterns) + + if len(attn_layers) == 0: logger.warning( - "Attention + quant fusion is enabled, but " - "CompilationConfig.static_forward_context is empty. " - "Cannot access attention layers so no fusion " - "patterns were registered.") + "Attention + quant fusion is enabled, but no attention layers " + "were found in CompilationConfig.static_forward_context " + "so no fusion patterns were registered.") def __call__(self, graph: torch.fx.graph.Graph) -> None: self.begin() self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + + # TODO: Move this to pass_manager.py after the fx graph broken issue + # has been resolved. + # see https://github.com/vllm-project/vllm/issues/23091 + graph.eliminate_dead_code() + logger.debug("Fused quantization onto %s attention nodes", count) self.dump_graph(graph, "after_attn_fusion") self.end_and_log() def uuid(self): - return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) + return VllmInductorPass.hash_source(self, AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 2a149c65b3877aee230f61241cb375a7b41cc2bc..e1b691df385d766cda1e955f35c3619213e3b14c 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import hashlib import inspect import json @@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union import torch from torch import fx +from torch._subclasses.fake_tensor import (FakeTensorMode, + unset_fake_temporarily) from vllm.utils import is_torch_equal_or_newer @@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass): def uuid(self) -> Any: return self._uuid + + +def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Applies a FakeTensorMode context. This is useful when you don't want to + create or run things with real tensors. + """ + + @functools.wraps(fn) + def fn_new(*args, **kwargs) -> Any: + with torch._guards.tracing( + None), unset_fake_temporarily(), FakeTensorMode(): + result = fn(*args, **kwargs) + + return result + + return fn_new diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 9047bf3cbf8e8f0b94df7c68c83a3509abaf0c7e..c46721ab2d74667a577bca4ee2ac412ea13de7e2 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -43,7 +43,7 @@ cudagraph_capturing_enabled: bool = True def validate_cudagraph_capturing_enabled(): - # used to monitor whether an cudagraph capturing is legal at runtime. + # used to monitor whether a cudagraph capturing is legal at runtime. # should be called before any cudagraph capturing. # if an illegal cudagraph capturing happens, raise an error. global cudagraph_capturing_enabled diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 79f28dd1ed07ead266be7b508dd2c424c446ba4b..a521fcba462512f7fb7aa13e91278703697024d6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -8,13 +8,13 @@ from vllm.logger import init_logger from vllm.platforms import current_platform if current_platform.is_cuda_alike(): + from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import FusionPass from .fusion_attn import AttnFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass -# from .activation_quant_fusion import ActivationQuantFusionPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index aadf54c827e82920a4237aa45587e66ccff5ca3e..dedd6ea311576c15d33d7e007bf92cf410f3e851 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import ( from vllm.logger import init_logger from vllm.platforms import current_platform +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass): performance. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index cd5531d3652f3fcb328006c32f665c5c5d3ed076..02ccad0911faf376e64f4e908ab82dd32db4f9ed 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -36,7 +36,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) -from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig +from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, + ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config from vllm.logger import init_logger @@ -199,7 +200,17 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: yield a, b a = b - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + try: + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + except (OSError, KeyError, TypeError): + # HACK: Python 3.13+ workaround - set missing __firstlineno__ + # Workaround can be removed after we upgrade to pydantic==2.12.0 + with open(inspect.getfile(cls)) as f: + for i, line in enumerate(f): + if f"class {cls.__name__}" in line and ":" in line: + cls.__firstlineno__ = i + 1 + break + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] if not isinstance(cls_node, ast.ClassDef): raise TypeError("Given object was not a class.") @@ -254,8 +265,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "cpm", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", - "processed_logits"] +MMEncoderTPMode = Literal["weights", "data"] + + +class LogprobsMode(enum.Enum): + RAW_LOGITS = "raw_logits" + RAW_LOGPROBS = "raw_logprobs" + PROCESSED_LOGITS = "processed_logits" + PROCESSED_LOGPROBS = "processed_logprobs" @config @@ -359,12 +376,13 @@ class ModelConfig: specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. - Raw means the values before applying logit processors, like bad words. - Processed means the values after applying such processors. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding @@ -427,7 +445,7 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs. @@ -436,6 +454,19 @@ class ModelConfig: `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. Set to `0` to disable this cache completely (not recommended).""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -478,6 +509,8 @@ class ModelConfig: logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None """One or more logits processors' fully-qualified class names or class definitions""" + io_processor_plugin: Optional[str] = None + """IOProcessor plugin name to load at model startup""" enable_chunked_prefill: Optional[bool] = None """If True, prefill requests can be chunked based @@ -854,22 +887,25 @@ class ModelConfig: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: + if (self.mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`.") + self.mm_encoder_tp_mode = "weights" + return MultiModalConfig( limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, interleave_mm_strings=self.interleave_mm_strings, - skip_mm_profiling=self.skip_mm_profiling) + skip_mm_profiling=self.skip_mm_profiling, + ) return None - def set_mm_processor_cache_gb(self, value: int) -> None: - mm_config = self.get_multimodal_config() - - self.mm_processor_cache_gb = value - mm_config.mm_processor_cache_gb = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1099,10 +1135,22 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc", - "slimquant_w4a8", "slimquant_w4a8_marlin" + "fp8", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed-tensors", + "experts_int8", + "quark", + "modelopt_fp4", + "bitblas", + "gptq_bitblas", + "inc", + "petit_nvfp4", + "slimquant_w4a8", + "slimquant_w4a8_marlin" ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1125,7 +1173,6 @@ class ModelConfig: # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). overrides = [ - "marlin", "bitblas", "gptq_marlin_24", "gptq_marlin", @@ -1136,6 +1183,7 @@ class ModelConfig: "slimquant_w4a8_marlin" "modelopt", "modelopt_fp4", + "petit_nvfp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides @@ -1470,7 +1518,8 @@ class ModelConfig: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp"): + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1670,29 +1719,8 @@ class ModelConfig: return self.multimodal_config is not None @property - def processor_return_mm_hashes(self) -> bool: - """Whether the multi-modal processor should output hashes.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - @property - def enable_mm_processor_cache(self) -> bool: - """Whether the multi-modal processor cache should be enabled.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - def get_mm_input_cache_gb(self) -> int: - mm_config = self.multimodal_config - if mm_config is None: - return 0 - - return envs.VLLM_MM_INPUT_CACHE_GIB + def is_multimodal_raw_input_only_model(self) -> bool: + return self._model_info.supports_multimodal_raw_input_only @property def is_cross_encoder(self) -> bool: @@ -1703,10 +1731,6 @@ class ModelConfig: def is_pp_supported(self) -> bool: return self._model_info.supports_pp - @property - def is_multimodal_raw_input_supported(self) -> bool: - return self._model_info.supports_multimodal_raw_input - @property def is_attention_free(self) -> bool: return self._model_info.is_attention_free @@ -1917,7 +1941,8 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp"] + "mlp_speculator", "draft_model", "deepseek_mtp", + "ernie_mtp"] @config @@ -2050,6 +2075,16 @@ class SpeculativeConfig: "architectures": ["Glm4MoeMTPModel"] }) + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + return hf_config + return hf_config def __post_init__(self): @@ -2068,8 +2103,8 @@ class SpeculativeConfig: if self.target_model_config and \ (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type \ - == "mimo"): + self.target_model_config.hf_text_config.model_type in + ("mimo","ernie4_5_moe")): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -2167,6 +2202,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." ) + elif (self.draft_model_config.hf_config.model_type == + "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( @@ -2386,7 +2430,7 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") def __repr__(self) -> str: method = self.method @@ -2422,8 +2466,8 @@ class LoRAConfig: lora_dtype: Union[torch.dtype, LoRADType] = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" lora_extra_vocab_size: int = 256 - """Maximum size of extra vocabulary that can be present in a LoRA adapter - (added to the base model vocabulary).""" + """(Deprecated) Maximum size of extra vocabulary that can be present in a + LoRA adapter. Will be removed in v0.12.0.""" lora_vocab_padding_size: ClassVar[int] = current_platform\ .get_lora_vocab_padding_size() @@ -2465,6 +2509,12 @@ class LoRAConfig: return hash_str def __post_init__(self): + # Deprecation warning for lora_extra_vocab_size + logger.warning( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out.") + # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) @@ -2529,7 +2579,7 @@ class MultiModalConfig: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """ The size (in GiB) of the multi-modal processor cache, which is used to @@ -2540,6 +2590,22 @@ class MultiModalConfig: Set to `0` to disable this cache completely (not recommended). """ + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """ + Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP. + """ + interleave_mm_strings: bool = False """ Enable fully interleaved support for multimodal prompts. @@ -2547,7 +2613,7 @@ class MultiModalConfig: skip_mm_profiling: bool = False """ - When enabled, skips multimodal memory profiling and only profiles with + When enabled, skips multimodal memory profiling and only profiles with language backbone model during engine initialization. This reduces engine startup time but shifts the responsibility to users for @@ -2610,24 +2676,24 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. """ dimensions: Optional[int] = None """ - Reduce the dimensions of embeddings if model + Reduce the dimensions of embeddings if model support matryoshka representation. """ ## for classification models activation: Optional[bool] = None """ - Whether to apply activation function to the classification outputs. + Whether to apply activation function to the classification outputs. """ ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the reward outputs. """ step_tag_id: Optional[int] = None """ @@ -2653,9 +2719,9 @@ class PoolerConfig: max_embed_len: Optional[int] = None """ - Maximum input length allowed for embedding generation. When set, allows + Maximum input length allowed for embedding generation. When set, allows inputs longer than max_embed_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring + This parameter enables accepting long inputs without requiring VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds max_embed_len, it will be handled according to the original max_model_len validation logic. Defaults to None (i.e. set to max_model_len). @@ -3009,7 +3075,8 @@ def get_served_model_name(model: str, return served_model_name -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"] +GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", + "lm-format-enforcer"] @config @@ -3572,7 +3639,7 @@ class VllmConfig: if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - if current_platform.is_cuda_alike(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): # if cudagraph_mode is not explicitly set by users, set default # value if self.compilation_config.cudagraph_mode is None: diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 6e1cf9d0b4d8cfab5cff4efc44887124f2bec36d..060ed4cec10e612b2c0250b0995211a31cbeddcd 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -115,8 +115,8 @@ class CacheConfig: In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables - attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) + attention metadata for eligible layers to be overridden with metadata + necessary for implementing this optimization in some models (e.g. Gemma3n) """ def compute_hash(self) -> str: @@ -145,12 +145,19 @@ class CacheConfig: self._verify_cache_dtype() self._verify_prefix_caching() + self._verify_kv_sharing_fast_prefill() def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} + def _verify_kv_sharing_fast_prefill(self) -> None: + if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not supported " + "in V0 currently.") + @model_validator(mode='after') def _verify_args(self) -> Self: if self.cpu_offload_gb < 0: @@ -162,11 +169,6 @@ class CacheConfig: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.kv_sharing_fast_prefill: - logger.warning_once( - "--kv-sharing-fast-prefill is currently work in progress " - "and not functional yet (i.e. no prefill savings)") - return self def _verify_cache_dtype(self) -> None: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 56a2183f8e2c1b121ba5a21a9bbe241417ca06c6..5c3b220016360d460d947015c40d349712417cb9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -225,7 +225,8 @@ class CompilationConfig: # CudaGraph compilation cudagraph_mode: Optional[CUDAGraphMode] = None """ - The mode of the cudagraph. + The mode of the cudagraph: + - NONE, no cudagraph capture. - PIECEWISE. (v1 default) - FULL. @@ -336,6 +337,9 @@ class CompilationConfig: "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.mamba_mixer", + "vllm.short_conv", + "vllm.linear_attention", ] def compute_hash(self) -> str: @@ -382,13 +386,10 @@ class CompilationConfig: if pass_config_exclude: exclude["pass_config"] = pass_config_exclude - # The cast to string is necessary because Pydantic is mocked in docs - # builds and sphinx-argparse doesn't know the return type of decode() - return str( - TypeAdapter(CompilationConfig).dump_json( - self, - exclude=exclude, # type: ignore[arg-type] - exclude_unset=True).decode()) + return TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode() __str__ = __repr__ diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 45d65fb3cb61533a3e126003c5f54b204aa415ae..6985a3f7eb05595cdd9334f37a6e4570aa99dced 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -15,7 +15,7 @@ import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_port +from vllm.utils import cuda_device_count_stateless, get_open_ports_list if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv @@ -32,6 +32,31 @@ logger = init_logger(__name__) DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +@config +@dataclass +class EPLBConfig: + """Configuration for Expert Parallel Load Balancing (EP).""" + + window_size: int = 1000 + """Window size for expert load recording.""" + step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `lb_window_size` steps will be used for rearranging experts. + """ + + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + + log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + @config @dataclass class ParallelConfig: @@ -75,22 +100,24 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - num_redundant_experts: int = 0 - """Number of redundant experts to use for expert parallelism.""" - eplb_window_size: int = 1000 - """Window size for expert load recording.""" - eplb_step_interval: int = 3000 - """ - Interval for rearranging experts in expert parallelism. - - Note that if this is greater than the EPLB window size, only the metrics - of the last `eplb_window_size` steps will be used for rearranging experts. - """ - eplb_log_balancedness: bool = False - """ - Log the balancedness each step of expert parallelism. - This is turned off by default since it will cause communication overhead. - """ + eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + """Expert parallelism configuration.""" + num_redundant_experts: Optional[int] = None + """`num_redundant_experts` is deprecated and has been replaced with + `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. + Please use `eplb_config.num_redundant_experts` instead.""" + eplb_window_size: Optional[int] = None + """`eplb_window_size` is deprecated and has been replaced with + `eplb_config.window_size`. This will be removed in v0.12.0. + Please use `eplb_config.window_size` instead.""" + eplb_step_interval: Optional[int] = None + """`eplb_step_interval` is deprecated and has been replaced with + `eplb_config.step_interval`. This will be removed in v0.12.0. + Please use `eplb_config.step_interval` instead.""" + eplb_log_balancedness: Optional[bool] = None + """`eplb_log_balancedness` is deprecated and has been replaced with + `eplb_config.log_balancedness`. This will be removed in v0.12.0. + Please use `eplb_config.log_balancedness` instead.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model @@ -109,7 +136,8 @@ class ParallelConfig: placement_group: Optional[PlacementGroup] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[DistributedExecutorBackend, + distributed_executor_backend: Optional[Union[str, + DistributedExecutorBackend, type[ExecutorBase]]] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product @@ -137,9 +165,10 @@ class ParallelConfig: rank: int = 0 """Global rank in distributed setup.""" - enable_multimodal_encoder_data_parallel: bool = False - """ Use data parallelism instead of tensor parallelism for vision encoder. - Only support LLama4 for now""" + _data_parallel_master_port_list: list[int] = field(default_factory=list) + """List of open port auto-queried for data parallel messaging. + Set to be private as it's not intended to be configured by users. + """ @property def world_size_across_dp(self) -> int: @@ -153,11 +182,15 @@ class ParallelConfig: processes that is related to data parallelism, e.g. both in the worker and in the engine, which can live in different processes. To avoid port conflicts, we - increment the port number each time we need to initialize a - new process group related to data parallelism. + pop a new port from the prepared port list each time we need to + initialize a new process group related to data parallelism. """ - answer = self.data_parallel_master_port - self.data_parallel_master_port += 1 + if self._data_parallel_master_port_list: + answer = self._data_parallel_master_port_list.pop() + else: + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer def stateless_init_dp_group(self) -> ProcessGroup: @@ -241,6 +274,38 @@ class ParallelConfig: return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Forward deprecated fields to their new location + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = ( + self.num_redundant_experts) + logger.warning_once( + "num_redundant_experts is deprecated and has been replaced " + "with eplb_config.num_redundant_experts. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + logger.warning_once( + "eplb_window_size is deprecated and has been replaced " + "with eplb_config.window_size. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + logger.warning_once( + "eplb_step_interval is deprecated and has been replaced " + "with eplb_config.step_interval. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + logger.warning_once( + "eplb_log_balancedness is deprecated and has been replaced " + "with eplb_config.log_balancedness. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + + # Continue with the rest of the initialization self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -251,7 +316,10 @@ class ParallelConfig: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. - self.data_parallel_master_port = get_open_port() + if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = \ + self._data_parallel_master_port_list.pop() if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( @@ -279,10 +347,10 @@ class ParallelConfig: raise ValueError( "Expert parallelism load balancing is only supported on " "CUDA devices now.") - if self.num_redundant_experts < 0: + if self.eplb_config.num_redundant_experts < 0: raise ValueError( "num_redundant_experts must be non-negative, but got " - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if not self.enable_expert_parallel: raise ValueError( "enable_expert_parallel must be True to use EPLB.") @@ -293,10 +361,10 @@ class ParallelConfig: f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." ) else: - if self.num_redundant_experts != 0: + if self.eplb_config.num_redundant_experts != 0: raise ValueError( "num_redundant_experts should be used with EPLB." - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -342,23 +410,22 @@ class ParallelConfig: def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( isinstance(self.distributed_executor_backend, type) - and self.distributed_executor_backend.uses_ray) + and getattr(self.distributed_executor_backend, "uses_ray", False)) @model_validator(mode='after') def _verify_args(self) -> Self: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform - if self.distributed_executor_backend not in ( - "ray", "mp", "uni", - "external_launcher", None) and not (isinstance( + if self.distributed_executor_backend is not None and not isinstance( + self.distributed_executor_backend, str) and not (isinstance( self.distributed_executor_backend, type) and issubclass( self.distributed_executor_backend, ExecutorBase)): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " - "values are 'ray', 'mp' 'uni', 'external_launcher' or" - " custom ExecutorBase subclass.") + "values are 'ray', 'mp' 'uni', 'external_launcher', " + " custom ExecutorBase subclass or its import path.") if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index dae6ead04e9c94c6f3c76186063839f996b33a04..7d9b32cd4b674a0d1d74c4672844d4162d5e041a 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator): Args: absolute_id (int): The absolute block id for the block - in whole allocator. + in whole allocator. Returns: int: The zero-offset block id on certain device. diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 2913a01bf34a5ec7a721f6d11522a180a97b5830..a21d69323abbc288720bb6fa7d11aa476d408152 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): Args: num_blocks (int): The total number of blocks to manage. block_size (int): The size of each block in tokens. - block_ids(Optional[Iterable[int]], optional): An optional iterable of + block_ids (Optional[Iterable[int]], optional): An optional iterable of block IDs. If not provided, block IDs will be assigned sequentially from 0 to num_blocks - 1. """ diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 4ec5a775f465c3adb40e44e99336a06c9c229dc8..cbfa4d7ff3c4cc0a1b696b03a39dbf13c17faa99 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap in. + seq_group (SequenceGroup): The sequence group to swap in. num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. @@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): Args: seq_group (SequenceGroup): The sequence group to swap out. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. Returns: bool: Whether it's possible to swap out current sequence group. @@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): swapping out the given sequence_group with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap out. + seq_group (SequenceGroup): The sequence group to swap out. Returns: List[Tuple[int, int]]: The mapping of swapping block from @@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): on to the 'device'. Args: - sequence_group (SequenceGroup): The sequence group to swap in/out. + seq_group (SequenceGroup): The sequence group to swap in/out. device (Device): device to swap the 'seq_group' on. status (SequenceStatus): The status of sequence which is needed for action. RUNNING for swap out and SWAPPED for swap in diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 7ec4768e90b1a537448f91bbb39135afa277e255..7a4a836ee348ecc7128e0f86cbd2012cd2c9ce98 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -76,7 +76,7 @@ class LRUEvictor(Evictor): that's recorded in the Block. If there are multiple blocks with the same last_accessed time, then the one with the largest num_hashed_tokens will be evicted. If two blocks each have the lowest last_accessed time and - highest num_hashed_tokens value, then one will be chose arbitrarily + highest num_hashed_tokens value, then one will be chosen arbitrarily """ # CLEANUP_THRESHOLD determines the maximum allowable size of the priority diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 63894e7f5dc8b8348b6d3cc5cb99a1429aaef261..d7864293e9647a7900aa30c4b5cd02f44b92f813 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -657,7 +657,7 @@ class Scheduler: `budget.num_batched_tokens` has not enough capacity to schedule all tokens. partial_prefill_metadata: information about the partial prefills - that are currently running + that are currently running Returns: SchedulerRunningOutputs. @@ -1591,7 +1591,6 @@ class Scheduler: encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, - token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed97ee3deee433d4b3c33e546e0363f09..7963fb15c4191afaac5f7937cf4b9b5a069368b0 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -152,8 +152,13 @@ class CuMemAllocator: self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. + # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" @@ -162,7 +167,7 @@ class CuMemAllocator: allocation_handle, self.current_tag) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" @@ -212,9 +217,9 @@ class CuMemAllocator: def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e870392292b962a48b37be029e6e1a623a..5c64e7d5c4ba3f2faf6b0ca23cc60c6bd676105f 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, logger = init_logger(__name__) +MiB = 1024 * 1024 +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: 2 * MiB, # 2 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 9e5aa4e4c2a891b3699aa9bb71c6eb4519481419..9131582eef7549559aa90fbd7fdedd57ade39a66 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -255,7 +255,7 @@ class DeviceCommunicatorBase: if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize() + module.quant_method.init_prepare_finalize(module) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 66d4940c9cec57aaa3e9ccd0cf256cca508305dc..eef3f9f75f9f1838848765e4b493a2c10fb37798 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) @@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): dtype=input_tensor.dtype, device=input_tensor.device) - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): device=input_tensor.device) if sizes is not None: - pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) else: - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 714df3916fc4bc64444bbef954ae9e7e747cf3e6..da9909dd9f625bdea0ab1d46b31f2fbbd2af0dbd 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,8 +11,8 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -114,7 +114,13 @@ class CustomAllreduce: # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..d907e1b833d040d1bf9e5d7a7452d8eeeab29487 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + + if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index c60a7a7eb25cf88c41a7dcb92d498ced35d6f2ae..942dd67f065dcc38230213211fd92936a1f300a3 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from .base_device_communicator import DeviceCommunicatorBase @@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config( logger = init_logger(__name__) -if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) - - if USE_RAY: - from vllm.executor import ray_utils +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + if USE_RAY: + from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): @@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase): return xm.all_gather(input_, dim=dim) -try: +if USE_TPU_COMMONS: from tpu_commons.distributed.device_communicators import ( TpuCommunicator as TpuCommonsCommunicator) TpuCommunicator = TpuCommonsCommunicator # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TpuCommunicator") - pass diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index dee5ed7a28830097b16945ba9159a55ee816c13b..067315deb773dcb2a10eabee34ed418378a77863 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -7,8 +7,13 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +import vllm.envs as envs +from vllm.logger import init_logger + from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class XpuCommunicator(DeviceCommunicatorBase): @@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase): device_group: Optional[ProcessGroup] = None, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 979f2a06cec9f2e0ecf4c0efde57d2351a0c64cb..d5ab61473ab01f22078a424e37a22caa9ef11d5c 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -244,7 +244,7 @@ class EplbState: dtype=torch.int32, device=device, ) - expert_load_window_size = parallel_config.eplb_window_size + expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), @@ -253,7 +253,7 @@ class EplbState: ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_step_interval + eplb_step_interval = parallel_config.eplb_config.step_interval expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) @@ -409,12 +409,14 @@ class EplbState: self.expert_rearrangement_step = 0 self.rearrange(model) - def rearrange(self, - model: MixtureOfExperts, - is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_load: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None) -> None: + def rearrange( + self, + model: MixtureOfExperts, + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, + int]] = None) -> Optional[torch.Tensor]: """ Rearrange the experts according to the current load. """ @@ -548,6 +550,7 @@ class EplbState: " (profile) " if is_profile else " ", time_end - time_start, ) + return None @staticmethod def recv_state() -> tuple[torch.Tensor, torch.Tensor]: @@ -613,4 +616,4 @@ def _node_count_with_rank_mapping( if is_same_node and node_assignment[other_rank] == 0: node_assignment[other_rank] = next_node_id - return next_node_id \ No newline at end of file + return next_node_id diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 2d7935773dd9f03ed2e8033680c7dce97b6c2290..37f8f72fa90569da143931e78c2da5c62271f28a 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -40,16 +40,21 @@ class KVCacheEvent( """Base class for all KV cache-related events""" +MEDIUM_GPU = "GPU" + + class BlockStored(KVCacheEvent): block_hashes: list[int] parent_block_hash: Optional[int] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index 349d3dfbd84fcfa34c1ea63050c79189d34d40be..39377aabcce3a93c04faecd157eb550f13eb6793 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -2,7 +2,7 @@ # Distributed KV cache transfer This folder implements distributed KV cache transfer across vLLM instances. -Currently the main usecase is for disaggregated prefilling. +Currently the main use case is for disaggregated prefilling. ## Abstractions @@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed communication service already supports key-value-based lookup (like redis or RDMA database). diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 07fcdecac62763e49f62213dbb2783ae27775d1f..2804003f5a7085b7fcb843a602655e7c03d78ea9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -19,6 +19,8 @@ The class provides the following primitives: Returns whether KV cache should be freed now or will be freed asynchronously and optionally returns KV transfer params. + take_events() - returns new KV events that were collected + by the connector since the last call. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -34,6 +36,7 @@ The class provides the following primitives: import enum from abc import ABC, abstractmethod +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: kv_caches: - dictionary of layer names, kv cache + Args: + kv_caches: dictionary of layer names, kv cache """ return @@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC): """ return False, None + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + return () + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d3f6a226dc72cbd778b61547d2787091a6fb23c5..65bcb4d93b1e13b443ce862e46ae9c4fb504c504 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1): return async_saves > 0, kv_txfer_params + def take_events(self) -> Iterable[KVCacheEvent]: + for c in self._connectors: + yield from c.take_events() + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4f51229ffbd261216417d39b907fe0b3a0ab02fb..6608d2a4a9e098766901a131d9c0e6afecaf81cd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -686,9 +686,6 @@ class NixlConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, first_kv_cache = next(iter(kv_caches.items())) - kv_elem_size = first_kv_cache.element_size() - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -701,66 +698,16 @@ class NixlConnectorWorker: "host_xfer_buffer should not be initialized when " f"kv_buffer_device is {self.kv_buffer_device}") - # TODO(tms): Find a more robust way to detect and handle MLA - # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected - # KV memory layout is HND, as opposed to the default NHD. Note that it - # will only affects the strides. For MLA instead, we make require no - # such thing and resort to the standard layout. - use_mla = len(first_kv_cache.shape) == 3 - if self.device_type == "tpu": - assert not use_mla, f"{self.kv_buffer_device} does not support MLA." - assert self._use_pallas_v1, f"attn backend: {self.backend_name}" - # tpu (v1) kv shape per layer: - # (num_blocks, block_size, num_kv_heads * 2, head_size) - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads_x_2, head_dim = block_shape - self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim - elif self.device_type == "cuda": - assert use_mla == self.use_mla - # TODO (NickLucche) not compatible with hybrid allocator. - # Enforce check once it goes live, as a single kv layout - # is expected for xfers. - if use_mla: - # MLA case. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size - else: - raise RuntimeError( - f"{self.device_type} ({self.backend_name}) is not supported.") - - # TODO(tms): self.block_len needs to be per-layer for sliding window, - # hybrid attn, etc - # block size in bytes - self.block_len = kv_elem_size * math.prod(block_shape) logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " - "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, self.num_blocks, block_shape, - first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks - self.device_kv_caches = kv_caches - kv_caches_base_addr = [] + "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, + self.use_host_buffer) + caches_data = [] + # With hybrid allocator, layers can share a kv cache tensor + seen_base_addresses = [] + xfer_buffers = (self.host_xfer_buffers + if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -770,42 +717,35 @@ class NixlConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in xfer_buffers.values(): - # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla \ - or self._use_pallas_v1 or self._use_flashinfer \ - else cache_or_caches + split_k_and_v = not (self.use_mla or self._use_pallas_v1 + or self._use_flashinfer) + tensor_size_bytes = None + for layer_name, cache_or_caches in xfer_buffers.items(): + cache_list = cache_or_caches if split_k_and_v else [ + cache_or_caches + ] + for cache in cache_list: base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len - # NOTE: use tp_rank for device_id since multi-node TP - # is rarely used. - caches_data.append((base_addr, region_len, self.tp_rank, "")) - kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" + caches_data.append( + (base_addr, tensor_size_bytes, self.tp_rank, "")) + + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - if self.vllm_config.model_config.hf_config.model_type == "llama4": - from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) - llama4_config = self.vllm_config.model_config.hf_text_config - no_rope_layers = llama4_config.no_rope_layers - chunk_size = llama4_config.attention_chunk_size - chunk_block_size = math.ceil(chunk_size / self.block_size) - for layer_idx in range(self.num_layers): - # no_rope_layers[layer_idx] == 0 means NoPE (global) - # Any other value means RoPE (local chunked) - is_local_attention = no_rope_layers[layer_idx] != 0 - block_window = chunk_block_size if is_local_attention else None - self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) - assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) @@ -813,9 +753,20 @@ class NixlConnectorWorker: logger.debug("Done registering descs") self._registered_descs.append(descs) + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.slot_size_bytes = self.block_len // self.block_size + if self._use_flashinfer: + assert self.slot_size_bytes % 2 == 0 + self.slot_size_bytes /= 2 + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks + # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: + for base_addr in seen_base_addresses: # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to @@ -836,6 +787,26 @@ class NixlConnectorWorker: self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) + # TODO(mgoin): Hybrid memory allocator is currently diabled for + # models with local attention (Llama 4). Can remove this once enabled. + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 5053a86887ac38c58015059600ab5df95598b02b..53ab2c5e36470365baff0a43c5d08aaea83cc62d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -30,27 +30,19 @@ logger = init_logger(__name__) class ReqMeta: # Request Id request_id: str - # Request tokens - token_ids: torch.Tensor - # Slot mappings, should have the same length as token_ids - slot_mapping: torch.Tensor + # Request block ids + block_ids: torch.Tensor + # Request num tokens + num_tokens: int @staticmethod def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], block_size: int) -> "ReqMeta": - valid_num_tokens = len(token_ids) - token_ids_tensor = torch.tensor(token_ids) block_ids_tensor = torch.tensor(block_ids) - num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size - slot_mapping = slot_mapping.flatten()[:valid_num_tokens] - return ReqMeta( request_id=request_id, - token_ids=token_ids_tensor, - slot_mapping=slot_mapping, + block_ids=block_ids_tensor, + num_tokens=len(token_ids), ) @@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1): return def inject_kv_into_layer( - dst_kv_cache_layer: torch.Tensor, - src_kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, + layer: torch.Tensor, + kv_cache: torch.Tensor, + block_ids: torch.Tensor, request_id: str, ) -> None: - """Inject the KV cache into the layer. + """ + Inject KV cache data into a given attention layer tensor. + + This function updates `layer` in-place with values from `kv_cache`, + handling different backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + If the number of provided block IDs does not match the number of KV + blocks, only the overlapping portion is updated, and a warning is + logged. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not - using MLA, [num_pages, page_size, xxx] otherwise. - src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] - otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape - [num_tokens]. - request_id (str): request id for log + layer (torch.Tensor): The attention layer KV tensor to update. + kv_cache (torch.Tensor): The KV cache tensor to inject. + block_ids (torch.Tensor): Indices of the blocks to update. + request_id (str): Request identifier used for logging. + + Returns: + None. The function modifies `layer` in-place. """ - dst_kv_cache_layer_shape = dst_kv_cache_layer.shape - if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): - num_pages = dst_kv_cache_layer_shape[0] - page_size = dst_kv_cache_layer_shape[1] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 0) - num_token = src_kv_cache.shape[0] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) + or layer.shape[1] == 2): # MLA or FlashInfer + num_block = kv_cache.shape[0] + self.check_tensors_except_dim(layer, kv_cache, 0) + if len(block_ids) == num_block: + layer[block_ids, ...] = kv_cache else: - dst_kv_cache_layer[slot_mapping[:num_token], - ...] = src_kv_cache + layer[block_ids[:num_block], ...] = kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) - - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) - else: - num_pages = dst_kv_cache_layer_shape[1] - page_size = dst_kv_cache_layer_shape[2] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 1) - num_token = src_kv_cache.shape[1] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) + + elif layer.shape[0] == 2: # FlashAttention + num_block = kv_cache.shape[1] + self.check_tensors_except_dim(layer, kv_cache, 1) + if len(block_ids) == num_block: + layer[:, block_ids, ...] = kv_cache else: - dst_kv_cache_layer[:, slot_mapping[:num_token], - ...] = src_kv_cache + layer[:, block_ids[:num_block], ...] = kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) - - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) # Get the metadata metadata: KVConnectorMetadata = \ @@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1): if kv_cache is None: continue - kv_cache_layer = kv_cache[ \ - forward_context.virtual_engine] + layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( request.request_id + "#" + layer_name) if kv_cache is None: - logger.warning("🚧src_kv_cache is None, %s", - request.request_id) + logger.warning("🚧kv_cache is None, %s", request.request_id) continue - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping, request.request_id) + inject_kv_into_layer(layer, kv_cache, request.block_ids, + request.request_id) tensor = self.p2p_nccl_engine.recv_store.pop(request.request_id + "#" + layer_name, None) if tensor is not None: del tensor @@ -248,16 +233,46 @@ class P2pNcclConnector(KVConnectorBase_V1): assert self.p2p_nccl_engine is not None + def extract_kv_from_layer( + layer: torch.Tensor, + block_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Extract KV cache slices from a given attention layer tensor. + + This function handles multiple backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + Args: + layer (torch.Tensor): The KV cache from the attention layer. + block_ids (torch.Tensor): Indices of blocks to extract. + + Returns: + torch.Tensor: A tensor containing the extracted KV slices. + Returns None if the layout is unsupported. + """ + if (isinstance(attn_metadata, MLACommonMetadata) + or layer.shape[1] == 2): # MLA or FlashInfer + return layer[block_ids, ...] + + if layer.shape[0] == 2: # FlashAttention + return layer[:, block_ids, ...] + + return None + connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: request_id = request.request_id ip, port = self.parse_request_id(request_id, True) remote_address = ip + ":" + str(port + self._rank) - self.p2p_nccl_engine.send_tensor( - request_id + "#" + layer_name, kv_layer, remote_address, - request.slot_mapping, - isinstance(attn_metadata, MLACommonMetadata)) + + kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) + self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, + kv_cache, remote_address) def wait_for_save(self): if self.is_producer: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index b94f2296dcb360e1eef8deab632aa6603fdc78db..dfd95548c4632a3f0178bb5bdeb3c0a0871e73a5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -62,8 +62,6 @@ class SendQueueItem: tensor_id: str remote_address: str tensor: torch.Tensor - slot_mapping: torch.Tensor - is_mla: bool class P2pNcclEngine: @@ -202,8 +200,6 @@ class P2pNcclEngine: tensor_id: str, tensor: torch.Tensor, remote_address: typing.Optional[str] = None, - slot_mapping: torch.Tensor = None, - is_mla: bool = False, ) -> bool: if remote_address is None: with self.recv_store_cv: @@ -213,9 +209,7 @@ class P2pNcclEngine: item = SendQueueItem(tensor_id=tensor_id, remote_address=remote_address, - tensor=tensor, - slot_mapping=slot_mapping, - is_mla=is_mla) + tensor=tensor) if self.send_type == "PUT": return self.send_sync(item) @@ -433,9 +427,7 @@ class P2pNcclEngine: if item.remote_address not in self.socks: self.create_connect(item.remote_address) - with self.send_stream: - tensor = self.extract_kv_from_layer(item.is_mla, item.tensor, - item.slot_mapping) + tensor = item.tensor sock = self.socks[item.remote_address] comm, rank = self.comms[item.remote_address] @@ -548,21 +540,3 @@ class P2pNcclEngine: self._send_thread.join() if self._ping_thread is not None: self._ping_thread.join() - - @staticmethod - def extract_kv_from_layer( - is_mla: bool, - layer: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> torch.Tensor: - """Extract the KV cache from the layer. - Assume the shape of the layer is (2, num_pages, page_size, xxx) - if MLA is not used, and (num_pages, page_size, xxx) otherwise. - """ - if is_mla: - num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] - - num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index 02e3bc6274f60bd1c0b476ca6f11221e7466abdb..b775276d4a846b055bed2906d75397eabc8ab845 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -99,8 +99,9 @@ class TensorMemoryPool: addr=self.base_address) self.free_lists[self.max_block_size][ initial_block.addr] = initial_block - logger.debug("TensorMemoryPool, base_address:", self.base_address, - self.base_address % self.max_block_size) + + logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, self.max_block_size) def allocate(self, size: int) -> int: """Allocates a memory block of at least the requested size. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e98c7e9e0814844dd76d704357cb5832a65e6cae..4d637b46bf2bf35f4c5abe70213c839cfbd4a2b8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,13 +9,13 @@ import dataclasses import functools import json import sys -import threading from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) +import huggingface_hub import regex as re import torch from pydantic import TypeAdapter, ValidationError @@ -25,22 +25,22 @@ import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, ConvertOption, DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, ModelConfig, ModelDType, - ModelImpl, MultiModalConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, - RunnerOption, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs, get_field) + LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, + ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import is_interleaved +from vllm.transformers_utils.config import get_model_path, is_interleaved from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) @@ -153,9 +153,17 @@ def is_online_quantization(quantization: Any) -> bool: return quantization in ["inc"] +NEEDS_HELP = ( + "--help" in (argv := sys.argv) # vllm SUBCOMMAND --help + or (argv0 := argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND +) + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: - cls_docs = get_attr_docs(cls) + # Save time only getting attr docs if we're generating help text + cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} for field in fields(cls): # Get the set of possible types for the field @@ -173,7 +181,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the help text for the field name = field.name - help = cls_docs[name].strip() + help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -255,6 +263,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: def get_kwargs(cls: ConfigType) -> dict[str, Any]: """Return argparse kwargs for the given Config dataclass. + If `--help` or `mkdocs` are not present in the command line command, the + attribute documentation will not be included in the help output. + The heavy computation is cached via functools.lru_cache, and a deep copy is returned so callers can mutate the dictionary without affecting the cached version. @@ -291,7 +302,7 @@ class EngineArgs: # is intended for expert use only. The API may change without # notice. distributed_executor_backend: Optional[Union[ - DistributedExecutorBackend, + str, DistributedExecutorBackend, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size @@ -305,11 +316,12 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - num_redundant_experts: int = ParallelConfig.num_redundant_experts - eplb_window_size: int = ParallelConfig.eplb_window_size - eplb_step_interval: int = ParallelConfig.eplb_step_interval - eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + num_redundant_experts: int = EPLBConfig.num_redundant_experts + eplb_window_size: int = EPLBConfig.window_size + eplb_step_interval: int = EPLBConfig.step_interval + eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -351,7 +363,9 @@ class EngineArgs: mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED - mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode + io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False @@ -436,16 +450,14 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_multimodal_encoder_data_parallel: bool = \ - ParallelConfig.enable_multimodal_encoder_data_parallel + # DEPRECATED + enable_multimodal_encoder_data_parallel: bool = False logits_processors: Optional[list[Union[ str, type[LogitsProcessor]]]] = ModelConfig.logits_processors """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - # DEPRECATED - enable_prompt_adapter: bool = False kv_sharing_fast_prefill: bool = \ CacheConfig.kv_sharing_fast_prefill @@ -457,9 +469,18 @@ class EngineArgs: if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig( **self.compilation_config) + if isinstance(self.eplb_config, dict): + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() + # when use hf offline,replace model id to local model path + if huggingface_hub.constants.HF_HUB_OFFLINE: + model_id = self.model + self.model = get_model_path(self.model, self.revision) + logger.info( + "HF_HUB_OFFLINE is True, replace model_id [%s] " \ + "to model_path [%s]",model_id, self.model) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -508,6 +529,7 @@ class EngineArgs: model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", + choices=[f.value for f in LogprobsMode], **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) @@ -559,6 +581,8 @@ class EngineArgs: **model_kwargs["override_attention_dtype"]) model_group.add_argument("--logits-processors", **model_kwargs["logits_processors"]) + model_group.add_argument("--io-processor-plugin", + **model_kwargs["io_processor_plugin"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -597,7 +621,7 @@ class EngineArgs: **guided_decoding_kwargs["disable_additional_properties"]) guided_decoding_group.add_argument( "--reasoning-parser", - # This choices is a special case because it's not static + # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), **guided_decoding_kwargs["reasoning_backend"]) @@ -657,14 +681,32 @@ class EngineArgs: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-redundant-experts", - **parallel_kwargs["num_redundant_experts"]) - parallel_group.add_argument("--eplb-window-size", - **parallel_kwargs["eplb_window_size"]) - parallel_group.add_argument("--eplb-step-interval", - **parallel_kwargs["eplb_step_interval"]) - parallel_group.add_argument("--eplb-log-balancedness", - **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument("--eplb-config", + **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--num-redundant-experts", + type=int, + help= + "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-window-size", + type=int, + help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-step-interval", + type=int, + help= + "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-log-balancedness", + action=argparse.BooleanOptionalAction, + help= + "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -680,7 +722,8 @@ class EngineArgs: **parallel_kwargs["worker_extension_cls"]) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", - **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + action="store_true", + deprecated=True) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -730,6 +773,8 @@ class EngineArgs: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -869,12 +914,6 @@ class EngineArgs: parser.add_argument('--disable-log-stats', action='store_true', help='Disable logging statistics.') - parser.add_argument('--enable-prompt-adapter', - action='store_true', - deprecated=True, - help='[DEPRECATED] Prompt adapter has been ' - 'removed. Setting this flag to True or False' - ' has no effect on vLLM behavior.') return parser @@ -914,6 +953,14 @@ class EngineArgs: self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB + if self.enable_multimodal_encoder_data_parallel: + logger.warning( + "--enable-multimodal-encoder-data-parallel` is deprecated " + "and will be removed in v0.13. " + "Please use `--mm-encoder-tp-mode data` instead.") + + self.mm_encoder_tp_mode = "data" + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -952,6 +999,7 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -961,6 +1009,7 @@ class EngineArgs: model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, + io_processor_plugin=self.io_processor_plugin, enable_chunked_prefill=self.enable_chunked_prefill, ) @@ -1022,11 +1071,11 @@ class EngineArgs: self.trust_remote_code, self.revision, self.code_revision, self.config_format) - # if loading a SpeculatorsConfig, load the specualtive_config + # if loading a SpeculatorsConfig, load the speculative_config # details from the config directly # no user input required / expected if isinstance(hf_config, SpeculatorsConfig): - # We create one since we dont create one + # We create one since we don't create one self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens @@ -1090,12 +1139,13 @@ class EngineArgs: # Set default arguments for V0 or V1 Engine. if use_v1: self._set_default_args_v1(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1 + # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 if current_platform.is_cpu( ) and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.ARM): + CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): logger.info( - "Chunked prefill is not supported for ARM and POWER CPUs; " + "Chunked prefill is not supported for ARM and POWER " + "and S390X CPUs; " "disabling it for V1 backend.") self.enable_chunked_prefill = False else: @@ -1238,6 +1288,16 @@ class EngineArgs: "Currently, speculative decoding is not supported with " "async scheduling.") + # Forward the deprecated CLI args to the EPLB config. + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = self.num_redundant_experts + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1251,10 +1311,7 @@ class EngineArgs: data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_redundant_experts=self.num_redundant_experts, - eplb_window_size=self.eplb_window_size, - eplb_step_interval=self.eplb_step_interval, - eplb_log_balancedness=self.eplb_log_balancedness, + eplb_config=self.eplb_config, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, @@ -1263,22 +1320,8 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_multimodal_encoder_data_parallel=self. - enable_multimodal_encoder_data_parallel, ) - if model_config.is_multimodal_model: - dp_supports_mm_processor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_processor_cache - and model_config.mm_processor_cache_gb > 0): - logger.warning( - "Multi-modal processor cache is disabled because " - "it is not compatible with data parallelism when " - "there does not exist a one-to-one correspondance " - "between API and engine core processes.") - model_config.set_mm_processor_cache_gb(0) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -1408,21 +1451,20 @@ class EngineArgs: recommend_to_remove=True) return False - # Need at least Ampere for now (FA support required). - # Skip this check if we are running on a non-GPU platform, - # or if the device capability is not available - # (e.g. in a Ray actor without GPUs). + # Triton v3.3 has f16 conversion regression issue on Turing and Volta, + # which broke fp16 inference + # see: https://github.com/triton-lang/triton/issues/6698 if (current_platform.is_cuda() - and current_platform.get_device_capability() - and current_platform.get_device_capability().major < 8): - _raise_or_fallback(feature_name="Compute Capability < 8.0", - recommend_to_remove=False) + and not current_platform.has_device_capability(80) + and model_config.dtype == torch.float16): + _raise_or_fallback( + feature_name="Compute Capability < 8.0 with FP16", + recommend_to_remove=False) return False - # No Fp8 KV cache so far. if self.kv_cache_dtype != "auto": supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype) + self.kv_cache_dtype, model_config) int8_attention = self.kv_cache_dtype.startswith("int8") if int8_attention: supported = True @@ -1443,11 +1485,6 @@ class EngineArgs: recommend_to_remove=False) return False - # V1 mamba models are unoptimized. - if model_config.has_inner_state and _warn_or_fallback( - feature_name="Mamba"): - return False - # No Concurrent Partial Prefills so far. if (self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills @@ -1503,11 +1540,6 @@ class EngineArgs: ############################################################# # Experimental Features - allow users to opt in. - # Signal Handlers requires running in main thread. - if (threading.current_thread() != threading.main_thread() - and _warn_or_fallback("Engine in background thread")): - return False - if self.pipeline_parallel_size > 1: supports_pp = getattr(self.distributed_executor_backend, 'supports_pp', False) @@ -1556,8 +1588,7 @@ class EngineArgs: use_spec_decode = self.speculative_config is not None if (is_gpu and not use_sliding_window and not use_spec_decode - and not self.enable_lora - and model_config.runner_type != "pooling"): + and not self.enable_lora): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models " @@ -1575,10 +1606,6 @@ class EngineArgs: "OOM during the initial memory profiling phase, or result " "in low performance due to small KV cache size. Consider " "setting --max-model-len to a smaller value.", max_model_len) - elif (self.enable_chunked_prefill - and model_config.runner_type == "pooling"): - msg = "Chunked prefill is not supported for pooling models" - raise ValueError(msg) # if using prefix caching, we must set a hash algo if self.enable_prefix_caching: @@ -1765,7 +1792,7 @@ class AsyncEngineArgs(EngineArgs): def add_cli_args(parser: FlexibleArgumentParser, async_args_only: bool = False) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may - # adding a new kind of quantization method to --quantization argument or + # add a new kind of quantization method to --quantization argument or # a new device to --device argument. load_general_plugins() if not async_args_only: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 84ad2299b06556f8d958831e5627db30d1a897d2..9f9ad1854c3b6b02de2fd13008ea4b23dd209000 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel class AsyncStream: - """A stream of RequestOutputs or PoolingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" + """A stream of RequestOutputs for a request that can be iterated over + asynchronously via an async generator.""" def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self.request_id = request_id @@ -81,8 +81,7 @@ class AsyncStream: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: + def put(self, item: Union[RequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -99,9 +98,7 @@ class AsyncStream: def finished(self) -> bool: return self._finished - async def generator( - self - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + async def generator(self) -> AsyncGenerator[RequestOutput, None]: try: while True: result = await self._queue.get() @@ -151,8 +148,7 @@ class RequestTracker: self.abort_request(rid, exception=exc) def process_request_output(self, - request_output: Union[RequestOutput, - PoolingRequestOutput], + request_output: RequestOutput, *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def step_async( - self, virtual_engine: int - ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + async def step_async(self, virtual_engine: int) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine): self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, @@ -486,10 +480,10 @@ class AsyncLLMEngine(EngineClient): _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - *args, + *args: Any, log_requests: bool = True, start_engine_loop: bool = True, - **kwargs) -> None: + **kwargs: Any) -> None: if envs.VLLM_USE_V1: raise ValueError( "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " @@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient): self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + ) -> AsyncGenerator[RequestOutput, None]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient): await self.abort(request_id) raise - async def encode( + def encode( self, prompt: PromptType, pooling_params: PoolingParams, @@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient): priority: int = 0, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - ``` - # Please refer to entrypoints/api_server.py for - # the complete example. - - # initialize the engine and the example input - # note that engine_args here is AsyncEngineArgs instance - engine = AsyncLLMEngine.from_engine_args(engine_args) - example_input = { - "input": "What is LLM?", - "request_id": 0, - } - - # start the generation - results_generator = engine.encode( - example_input["input"], - PoolingParams(), - example_input["request_id"]) - - # get the results - final_output = None - async for request_output in results_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - # Return or raise an error - ... - final_output = request_output - - # Process and return the final output - ... - ``` - """ - try: - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - tokenization_kwargs=tokenization_kwargs, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) - except asyncio.CancelledError: - await self.abort(request_id) - raise + raise NotImplementedError( + "Pooling models are not supported in vLLM V0") async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort a request. @@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient): async def is_sleeping(self) -> bool: return self.engine.is_sleeping() - async def add_lora(self, lora_request: LoRARequest) -> None: - self.engine.add_lora(lora_request) + async def add_lora(self, lora_request: LoRARequest) -> bool: + return self.engine.add_lora(lora_request) async def collective_rpc(self, method: str, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ac2214e7aa033981f9db1cda8b5b64a5fd152818..0f1aa31ac282c5c08460bad8c00e04e2b2015a54 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,15 +40,15 @@ from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) -from vllm.pooling_params import PoolingParams from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, - PoolingSequenceGroupOutput, Sequence, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupOutput, SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID) + Sequence, SequenceGroup, SequenceGroupBase, + SequenceGroupMetadata, SequenceGroupOutput, + SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.detokenizer import Detokenizer @@ -99,8 +99,7 @@ class SchedulerContext: def __init__(self) -> None: self.output_queue: Deque[OutputData] = deque() - self.request_outputs: List[Union[RequestOutput, - PoolingRequestOutput]] = [] + self.request_outputs: List[RequestOutput] = [] self.seq_group_metadata_list: Optional[ List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None @@ -261,14 +260,17 @@ class LLMEngine: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=processor_only_cache_from_config( + self.model_config, mm_registry), + ) self.model_executor = executor_class(vllm_config=vllm_config) - if self.model_config.runner_type != "pooling": - self._initialize_kv_caches() + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -553,7 +555,7 @@ class LLMEngine: self, request_id: str, processed_inputs: ProcessorInputs, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Mapping[str, str]] = None, @@ -593,7 +595,7 @@ class LLMEngine: encoder_seq = (None if encoder_inputs is None else Sequence( seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) - # Create a SequenceGroup based on SamplingParams or PoolingParams + # Create a SequenceGroup based on SamplingParams if isinstance(params, SamplingParams): seq_group = self._create_sequence_group_with_sampling( request_id, @@ -604,18 +606,8 @@ class LLMEngine: trace_headers=trace_headers, encoder_seq=encoder_seq, priority=priority) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - encoder_seq=encoder_seq, - priority=priority) else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + raise ValueError("SamplingParams must be provided.") # Add the sequence group to the scheduler with least unfinished seqs. costs = [ @@ -634,7 +626,7 @@ class LLMEngine: self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, @@ -652,9 +644,8 @@ class LLMEngine: prompt: The prompt to the LLM. See [PromptType][vllm.inputs.PromptType] for more details about the format of each input. - params: Parameters for sampling or pooling. + params: Parameters for sampling. [SamplingParams][vllm.SamplingParams] for text generation. - [PoolingParams][vllm.PoolingParams] for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. lora_request: The LoRA request to add. @@ -665,10 +656,10 @@ class LLMEngine: Details: - Set arrival_time to the current time if it is None. - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of [Sequence][vllm.Sequence] objects. - - Create a [SequenceGroup][vllm.SequenceGroup] object - from the list of [Sequence][vllm.Sequence]. - - Add the [SequenceGroup][vllm.SequenceGroup] object to the + - Create `n` number of [Sequence][vllm.sequence.Sequence] objects. + - Create a [SequenceGroup][vllm.sequence.SequenceGroup] object + from the list of [Sequence][vllm.sequence.Sequence]. + - Add the [SequenceGroup][vllm.sequence.SequenceGroup] object to the scheduler. Example: @@ -780,29 +771,6 @@ class LLMEngine: return seq_group - def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with PoolingParams.""" - # Defensive copy of PoolingParams, which are used by the pooler - pooling_params = pooling_params.clone() - # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - encoder_seq=encoder_seq, - priority=priority) - return seq_group - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -865,8 +833,8 @@ class LLMEngine: def reset_mm_cache(self) -> bool: """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache( - self.model_config) + self.input_preprocessor.clear_cache() + return True def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" @@ -876,18 +844,6 @@ class LLMEngine: success = success and scheduler.reset_prefix_cache(device) return success - @staticmethod - def _process_sequence_group_outputs( - seq_group: SequenceGroup, - outputs: List[PoolingSequenceGroupOutput], - ) -> None: - seq_group.pooled_data = outputs[0].data - - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_STOPPED - - return - def _process_model_outputs(self, ctx: SchedulerContext, request_id: Optional[str] = None) -> None: @@ -994,13 +950,10 @@ class LLMEngine: seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.runner_type == "pooling": - self._process_sequence_group_outputs(seq_group, output) - else: - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs( - seq_group, output, is_async) + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs(seq_group, output, + is_async) if seq_group.is_finished(): finished_now.append(i) @@ -1122,7 +1075,7 @@ class LLMEngine: seq.append_token_id(sample.output_token, sample.logprobs, sample.output_embed) - def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: + def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. <figure markdown="span"> @@ -1323,7 +1276,7 @@ class LLMEngine: # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks + # torch.distributed ops which may otherwise time out, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. logger.debug("Stopping remote worker execution loop.") @@ -1865,7 +1818,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ff0405d2f843e9fc95c993ff9c852ad1ae782c00..9f64ee0808df2da0e985b026c19689a37e522ea4 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -120,6 +120,7 @@ class RPCLoadAdapterRequest: @dataclass class RPCAdapterLoadedResponse: request_id: str + lora_loaded: bool RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 588a03a3beb4ef61405386c3793ff6d8ff6f506a..20d953adeeef6c0fe190235dc889d7168add755a 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -6,7 +6,7 @@ import copy import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, - Mapping, Optional, Union, cast) + Mapping, Optional, Union) import cloudpickle import psutil @@ -482,10 +482,8 @@ class MQLLMEngineClient(EngineClient): Any priority other than 0 will lead to an error if the scheduling policy is not "priority". """ - return cast( - AsyncGenerator[RequestOutput, None], - self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, priority)) + return self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, priority) def encode( self, @@ -495,45 +493,20 @@ class MQLLMEngineClient(EngineClient): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. - """ - return cast( - AsyncGenerator[PoolingRequestOutput, None], - self._process_request(prompt, - pooling_params, - request_id, - lora_request, - trace_headers, - priority=priority)) + raise NotImplementedError( + "Pooling models are not supported in vLLM V0") async def _process_request( self, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - PoolingRequestOutput, None]]: + ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. @@ -544,7 +517,7 @@ class MQLLMEngineClient(EngineClient): if request_id in self.output_queues: raise ValueError(f"Request {request_id} already exists") - # 1) Create output queue for this requests. + # 1) Create output queue for this request. queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue @@ -552,7 +525,7 @@ class MQLLMEngineClient(EngineClient): try: # 2) Detach logits processors so that they can be pickled # separately (may require cloudpickle which is slower) - if isinstance(params, SamplingParams) and params.logits_processors: + if params.logits_processors: # Defensive shallow copy params = copy.copy(params) logits_processors = params.logits_processors @@ -651,13 +624,14 @@ class MQLLMEngineClient(EngineClient): raise request_output return request_output.is_sleeping - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" # Uses the same I/O as generate requests request = RPCLoadAdapterRequest(lora_request) - # Create output queue for this requests. - queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() + # Create output queue for this request. + queue: asyncio.Queue[Union[ + BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue() self.output_queues[request.request_id] = queue # Send the request @@ -671,3 +645,4 @@ class MQLLMEngineClient(EngineClient): # Raise on error, otherwise happily return None if isinstance(request_output, BaseException): raise request_output + return request_output.lora_loaded diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 02927093e5b69e584e3b208695e60f14044257be..8d9ea9cc8fb96ede63565b81151189f04384b07f 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -365,7 +365,7 @@ class MQLLMEngine: def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): try: - self.engine.add_lora(request.lora_request) + lora_loaded = self.engine.add_lora(request.lora_request) except BaseException as e: # Send back an error if the adater fails to load rpc_err = RPCError(request_id=request.request_id, @@ -375,7 +375,8 @@ class MQLLMEngine: return # Otherwise, send back the successful load message self._send_outputs( - RPCAdapterLoadedResponse(request_id=request.request_id)) + RPCAdapterLoadedResponse(request_id=request.request_id, + lora_loaded=lora_loaded)) def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): is_sleeping = self.is_sleeping() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index c610fb5eae60c6d3d38d180da624769e1817827b..b0b11a33a444398ad8048cfd6ce26f8b48886c37 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, Iterable, Mapping, Optional, Union +from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig, VllmConfig @@ -15,6 +15,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors.interface import IOProcessor from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -224,6 +225,7 @@ class EngineClient(ABC): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" ... @@ -266,6 +268,9 @@ class EngineClient(ABC): """Get the appropriate tokenizer for the request""" ... + async def get_io_processor(self) -> IOProcessor: + raise NotImplementedError + @abstractmethod async def is_tracing_enabled(self) -> bool: ... @@ -320,7 +325,7 @@ class EngineClient(ABC): ... @abstractmethod - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" ... @@ -329,3 +334,11 @@ class EngineClient(ABC): drain_timeout: int = 300) -> None: """Scale the engine""" raise NotImplementedError + + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """Perform a collective RPC call to the given path.""" + raise NotImplementedError diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 74c8093f496748aec6e122cca45613b6d9a01f2e..1954cbcbf1edd213e03b44124a500ada587b9aa6 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -640,7 +640,7 @@ class BaseMultiModalContentParser(ABC): def __init__(self) -> None: super().__init__() - # stores model placehodlers list with corresponding + # stores model placeholders list with corresponding # general MM placeholder: # { # "<##IMAGE##>": ["<image>", "<image>", "<image>"], @@ -1330,7 +1330,7 @@ def apply_mistral_chat_template( # mistral-common uses assert statements to stop processing of input # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be - # are properly caught in the preprocessing_input step + # properly caught in the preprocessing_input step except (AssertionError, MistralCommonException) as e: raise ValueError(str(e)) from e @@ -1345,5 +1345,18 @@ def apply_mistral_chat_template( "template") raise ValueError(str(e)) from e -def random_tool_call_id() -> str: - return f"chatcmpl-tool-{random_uuid()}" +def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): + idx = 0 + for msg in conversation: + if msg['role'] == 'assistant': + tool_calls = msg.get('tool_calls') + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + return idx + +def make_tool_call_id(id_type:str='random', func_name=None, idx=None): + + if id_type=='kimi_k2': + return f'functions.{func_name}:{idx}' + else: + # by default return random + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index e817f07ef59474077acc2241e8c5cdf3d0d09bb8..9d587e8669339c1308876e16c9f90b069da3572b 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -3,13 +3,16 @@ import json import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Union +from collections.abc import Sequence +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Optional, Union from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm.entrypoints.harmony_utils import ( get_encoding, get_streamable_parser_for_assistant, render_for_completion) from vllm.entrypoints.tool import Tool +from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput if TYPE_CHECKING: @@ -36,6 +39,11 @@ class ConversationContext(ABC): def render_for_completion(self) -> list[int]: pass + @abstractmethod + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + class SimpleContext(ConversationContext): @@ -54,28 +62,45 @@ class SimpleContext(ConversationContext): def render_for_completion(self) -> list[int]: raise NotImplementedError("Should not be called.") + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + class HarmonyContext(ConversationContext): def __init__( self, messages: list, - tool_sessions: dict[str, Tool], + available_tools: list[str], ): self._messages = messages - self.tool_sessions = tool_sessions + self.available_tools = available_tools + self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) - # TODO(woosuk): Implement the following fields. self.num_prompt_tokens = 0 - self.num_cached_tokens = 0 self.num_output_tokens = 0 + # TODO(woosuk): Implement the following fields. + self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + def _update_num_prompt_tokens(self, output: RequestOutput): + if output.prompt_token_ids and len(output.prompt_token_ids) > 0: + # NOTE: with built-in tools, there might be multiple rounds in + # the conversation, with the full conversation being resent + # as new prompt each time. Hence the sum. + self.num_prompt_tokens += len(output.prompt_token_ids) + + def _update_num_output_tokens(self, token_ids: Sequence[int]): + self.num_output_tokens += len(token_ids) + def append_output(self, output) -> None: if isinstance(output, RequestOutput): + self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids + self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) @@ -103,10 +128,10 @@ class HarmonyContext(ConversationContext): if recipient is not None: if recipient.startswith("browser."): return await self.call_search_tool( - self.tool_sessions["browser"], last_msg) + self._tool_sessions["browser"], last_msg) elif recipient.startswith("python"): return await self.call_python_tool( - self.tool_sessions["python"], last_msg) + self._tool_sessions["python"], last_msg) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: @@ -148,6 +173,15 @@ class HarmonyContext(ConversationContext): recipient=Role.ASSISTANT) ] + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + if tool_server: + for tool_name in self.available_tools: + if tool_name not in self._tool_sessions: + self._tool_sessions[ + tool_name] = await exit_stack.enter_async_context( + tool_server.new_session(tool_name)) + class StreamingHarmonyContext(HarmonyContext): @@ -158,6 +192,7 @@ class StreamingHarmonyContext(HarmonyContext): self.parser = get_streamable_parser_for_assistant() self.encoding = get_encoding() self.last_tok = None + self.first_tok_of_message = True @property def messages(self) -> list: @@ -165,8 +200,18 @@ class StreamingHarmonyContext(HarmonyContext): def append_output(self, output) -> None: if isinstance(output, RequestOutput): + # append_output is called for each output token in streaming case, + # so we only want to add the prompt tokens once for each message. + if self.first_tok_of_message: + self._update_num_prompt_tokens(output) + # Reset self.first_tok_of_message if needed: + # if the current token is the last one of the current message + # (finished=True), then the next token processed will mark the + # beginning of a new message + self.first_tok_of_message = output.finished tok = output.outputs[0].token_ids[0] self.parser.process(tok) + self._update_num_output_tokens(output.outputs[0].token_ids) self.last_tok = tok else: # Handle the case of tool output in direct message format diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index efca1472e44cfd9e81861d9323ba3b4f9b9e96f7..078d316844257b508dc51df6db46c8d5f2d0f8d5 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -155,7 +155,7 @@ def parse_chat_input(chat_msg) -> Message: contents = [TextContent(text=content)] else: # TODO: Support refusal. - contents = [TextContent(text=c["text"]) for c in content] + contents = [TextContent(text=c.get("text", "")) for c in content] msg = Message.from_role_and_contents(role, contents) return msg @@ -218,8 +218,8 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: ) output_items.append(reasoning_item) elif message.channel == "commentary": - if message.recipient.startswith("functions."): - function_name = message.recipient.split(".")[-1] + if recipient is not None and recipient.startswith("functions."): + function_name = recipient.split(".")[-1] for content in message.content: random_id = random_uuid() response_item = ResponseFunctionToolCall( @@ -230,8 +230,8 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: id=f"ft_{random_id}", ) output_items.append(response_item) - elif message.recipient.startswith( - "python") or message.recipient.startswith("browser"): + elif recipient is not None and (recipient.startswith("python") + or recipient.startswith("browser")): for content in message.content: reasoning_item = ResponseReasoningItem( id=f"rs_{random_uuid()}", @@ -245,7 +245,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: ) output_items.append(reasoning_item) else: - raise ValueError(f"Unknown recipient: {message.recipient}") + raise ValueError(f"Unknown recipient: {recipient}") elif message.channel == "final": contents = [] for content in message.content: @@ -329,23 +329,19 @@ def parse_chat_output( token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: parser = parse_output_into_messages(token_ids) output_msgs = parser.messages + is_tool_call = False # TODO: update this when tool call is supported if len(output_msgs) == 0: # The generation has stopped during reasoning. - is_tool_call = False reasoning_content = parser.current_content final_content = None elif len(output_msgs) == 1: # The generation has stopped during final message. - is_tool_call = False reasoning_content = output_msgs[0].content[0].text final_content = parser.current_content else: - if len(output_msgs) != 2: - raise ValueError( - "Expected 2 output messages (reasoning and final), " - f"but got {len(output_msgs)}.") - reasoning_msg, final_msg = output_msgs - reasoning_content = reasoning_msg.content[0].text + reasoning_msg = output_msgs[:-1] + final_msg = output_msgs[-1] + reasoning_content = "\n".join( + [msg.content[0].text for msg in reasoning_msg]) final_content = final_msg.content[0].text - is_tool_call = final_msg.recipient is not None return reasoning_content, final_content, is_tool_call diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cb71e31c9a037996545e8767658c978c7f4f41b8..42367a92a3b03eb7cac8ef9126ae9fc7cb5877cb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,15 +3,13 @@ import itertools from collections.abc import Sequence -from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, - cast, overload) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -39,14 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, # yapf: enable from vllm.entrypoints.utils import (_validate_truncation_size, log_non_default_args) -from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput, ScoringRequestOutput) +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, SamplingParams) @@ -54,9 +53,7 @@ from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext - -from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of - +from vllm.utils import Counter, Device, as_iter, is_list_of from vllm.v1.sample.logits_processor import LogitsProcessor import vllm.envs as envs @@ -163,18 +160,6 @@ class LLM: serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. """ - DEPRECATE_LEGACY: ClassVar[bool] = True - """A flag to toggle whether to deprecate the legacy generate/encode API.""" - - @classmethod - @contextmanager - def deprecate_legacy_api(cls): - cls.DEPRECATE_LEGACY = True - - yield - - cls.DEPRECATE_LEGACY = False - def __init__( self, model: str, @@ -209,7 +194,7 @@ class LLM: CompilationConfig]] = None, logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None, - **kwargs, + **kwargs: Any, ) -> None: """LLM constructor.""" @@ -311,6 +296,11 @@ class LLM: self.supported_tasks = supported_tasks + # Load the Input/Output processor plugin if any + io_processor_plugin = self.llm_engine.model_config.io_processor_plugin + self.io_processor = get_io_processor(self.llm_engine.vllm_config, + io_processor_plugin) + def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, @@ -337,99 +327,14 @@ class LLM: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() - @overload def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - /, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: str, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: list[str], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[int], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[list[str]] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: None, - sampling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def generate( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, priority: Optional[list[int]] = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -441,7 +346,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -472,37 +377,19 @@ class LLM: "Try passing `--runner generate` to use the model as a " "generative model.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if sampling_params is None: # Use default sampling params. sampling_params = self.get_default_sampling_params() - tokenization_kwargs: dict[str, Any] = {} - truncate_prompt_tokens = None - if isinstance(sampling_params, SamplingParams): - truncate_prompt_tokens = sampling_params.truncate_prompt_tokens - - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) - # Add any modality specific loras to the corresponding prompts lora_request = self._get_modality_specific_lora_reqs( - parsed_prompts, lora_request) + prompts, lora_request) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, priority=priority, ) @@ -510,7 +397,7 @@ class LLM: return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, parsed_prompts: Union[PromptType, Sequence[PromptType]], + self, prompts: Union[PromptType, Sequence[PromptType]], lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. @@ -523,35 +410,33 @@ class LLM: or (lora_config and lora_config.default_mm_loras is None)): return lora_request - if not isinstance(parsed_prompts, Sequence): - parsed_prompts = [parsed_prompts] + if not isinstance(prompts, Sequence): + prompts = [prompts] - optional_loras = ([lora_request] * len(parsed_prompts) + optional_loras = ([lora_request] * len(prompts) if not isinstance(lora_request, Sequence) else lora_request) return [ self._resolve_single_prompt_mm_lora( - parsed_prompt, + prompt, opt_lora_req, lora_config.default_mm_loras, - ) for parsed_prompt, opt_lora_req in zip(parsed_prompts, - optional_loras) + ) for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType, + def _resolve_single_prompt_mm_lora(self, prompt: PromptType, lora_request: Optional[LoRARequest], default_mm_loras: Optional[dict[str, str]]): - if (not default_mm_loras or not isinstance(parsed_prompt, dict) - or "multi_modal_data" not in parsed_prompt): + if (not default_mm_loras or not isinstance(prompt, dict) + or "multi_modal_data" not in prompt): return lora_request - parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt) + prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - intersection = set( - parsed_prompt["multi_modal_data"].keys()).intersection( - default_mm_loras.keys()) + intersection = set(prompt["multi_modal_data"].keys()) \ + .intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -646,6 +531,7 @@ class LLM: params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, use_tqdm: bool = False, + concurrency_limit: Optional[int] = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -656,6 +542,8 @@ class LLM: params: The beam search parameters. lora_request: LoRA request to use for generation, if any. use_tqdm: Whether to use tqdm to display the progress bar. + concurrency_limit: The maximum number of concurrent requests. + If None, the number of concurrent requests is unlimited. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -674,6 +562,15 @@ class LLM: length_penalty, ) + if use_tqdm and concurrency_limit is not None: + logger.warning( + "Progress bar is not supported when using concurrency_limit. " + "Disabling progress bar.") + use_tqdm = False + + if concurrency_limit is None: + concurrency_limit = len(prompts) + def create_tokens_prompt_from_beam( beam: BeamSearchSequence) -> TokensPrompt: token_prompt_kwargs: TokensPrompt = { @@ -718,73 +615,79 @@ class LLM: **mm_kwargs, ), ) - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. It does not " - "reflect instance-level progress.") - - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances), [])) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances)) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) - - if len(all_beams) == 0: - break - - # create the corresponding batch entries for prompt & optional lora - prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) - - # only runs for one step - # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) - - for (start, end), instance in zip(instance_start_and_end, - instances): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] - - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the max-model-len - # or abortion. we don't need to add it to the new beams. - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam. - mm_processor_kwargs) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) - instance.beams = sorted_beams[:beam_width] + for prompt_start in range(0, len(prompts), concurrency_limit): + instances_batch = instances[prompt_start:prompt_start + + concurrency_limit] + + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances_batch), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances_batch)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + # create corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[(create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams]) + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch) + + for (start, end), instance in zip(instance_start_and_end, + instances_batch): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the + # max-model-len or abortion. we don't need to add + # it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: @@ -820,8 +723,8 @@ class LLM: Generate responses for a chat conversation. The chat conversation is converted into a text prompt using the - tokenizer and calls the [generate][] method to generate the - responses. + tokenizer and calls the [generate][vllm.LLM.generate] method to generate + the responses. Multi-modal inputs can be passed in the same way you would pass them to the OpenAI API. @@ -945,11 +848,9 @@ class LLM: lora_request=lora_request, ) - @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, @@ -958,107 +859,6 @@ class LLM: lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, pooling_task: PoolingTask = "encode", tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: list[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[int], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[list[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: None, - pooling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def encode( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: Optional[PoolingTask] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1070,7 +870,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -1079,6 +879,8 @@ class LLM: If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. pooling_task: Override the pooling task to use. + tokenization_kwargs: overrides tokenization_kwargs set in + pooling_params Returns: A list of `PoolingRequestOutput` objects containing the @@ -1120,47 +922,62 @@ class LLM: raise ValueError( f"pooling_task must be one of {self.supported_tasks}.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - if isinstance(pooling_params, PoolingParams): - pooling_params.verify(pooling_task, model_config) - else: - for pooling_param in pooling_params: - pooling_param.verify(pooling_task, model_config) + for param in as_iter(pooling_params): + param.verify(pooling_task, model_config) + # for backwards compatibility + if truncate_prompt_tokens is not None: + param.truncate_prompt_tokens = truncate_prompt_tokens - if tokenization_kwargs is None: - tokenization_kwargs = dict[str, Any]() - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, - tokenization_kwargs) + io_processor_prompt = False + if isinstance(prompts, dict) and "data" in prompts: + io_processor_prompt = True + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details.") + + # Validate the request data is valid for the loaded plugin + validated_prompt = self.io_processor.parse_request(prompts) + + # obtain the actual model prompts from the pre-processor + prompts = self.io_processor.pre_process(prompt=validated_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, ) outputs = self._run_engine(use_tqdm=use_tqdm) - return self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + + model_outputs = self.engine_class.validate_outputs( + outputs, PoolingRequestOutput) + + if io_processor_prompt: + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process( + model_output=model_outputs) + + return [ + PoolingRequestOutput[Any](request_id="", + outputs=processed_outputs, + prompt_token_ids=[], + finished=True) + ] + else: + return model_outputs def embed( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -1178,7 +995,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -1210,7 +1027,6 @@ class LLM: def classify( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, pooling_params: Optional[Union[PoolingParams, @@ -1227,7 +1043,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. @@ -1271,7 +1087,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. @@ -1360,7 +1176,7 @@ class LLM: _validate_truncation_size(model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) - parsed_prompts = [] + prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] @@ -1375,8 +1191,7 @@ class LLM: tokenization_kwargs=tokenization_kwargs, ) - if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( - "token_type_ids", None)): + if (token_type_ids := engine_prompt.pop("token_type_ids", None)): params = pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) params.extra_kwargs = {"compressed_token_type_ids": compressed} @@ -1384,10 +1199,10 @@ class LLM: else: pooling_params_list.append(pooling_params) - parsed_prompts.append(engine_prompt) + prompts.append(engine_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, @@ -1571,8 +1386,8 @@ class LLM: def wake_up(self, tags: Optional[list[str]] = None): """ - Wake up the engine from sleep mode. See the [sleep][] method - for more details. + Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep] + method for more details. Args: tags: An optional list of tags to reallocate the engine memory @@ -1597,48 +1412,6 @@ class LLM: assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() - # LEGACY - def _convert_v1_inputs( - self, - prompts: Optional[Union[str, list[str]]], - prompt_token_ids: Optional[Union[list[int], list[list[int]]]], - ): - # skip_tokenizer_init is now checked in engine - - if prompts is None and prompt_token_ids is None: - raise ValueError( - "Either prompts or prompt_token_ids must be provided.") - if prompts is not None and prompt_token_ids is not None \ - and len(prompts) != len(prompt_token_ids): - raise ValueError( - "The lengths of prompts and prompt_token_ids must be the same." - ) - - if prompts is not None: - prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] - if prompt_token_ids is not None: - prompt_token_ids = [ - p["content"] for p in parse_and_batch_prompt(prompt_token_ids) - ] - if prompts is not None: - num_requests = len(prompts) - elif prompt_token_ids is not None: - num_requests = len(prompt_token_ids) - parsed_prompts: list[PromptType] = [] - for i in range(num_requests): - item: PromptType - - if prompts is not None: - item = TextPrompt(prompt=prompts[i]) - elif prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError - - parsed_prompts.append(item) - - return parsed_prompts - def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]], @@ -1647,7 +1420,6 @@ class LLM: *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], - tokenization_kwargs: Optional[dict[str, Any]] = None, priority: Optional[list[int]] = None, ) -> None: if isinstance(prompts, (str, dict)): @@ -1674,7 +1446,17 @@ class LLM: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm it = tqdm_func(it, desc="Adding requests") + model_config = self.llm_engine.model_config + for i, prompt in enumerate(it): + + param = params[i] if isinstance(params, Sequence) else params + + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(model_config.max_model_len, + param.truncate_prompt_tokens, + tokenization_kwargs) + self._add_request( prompt, params[i] if isinstance(params, Sequence) else params, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3b9244ac6059a24ad080f0ae7883335fcccab759..3cebfdf885becd553d3ed462f173cfc1c6db85b7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -64,6 +64,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, EmbeddingRequest, EmbeddingResponse, ErrorInfo, ErrorResponse, + IOProcessorResponse, LoadLoRAAdapterRequest, PoolingRequest, PoolingResponse, RerankRequest, RerankResponse, @@ -600,8 +601,11 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") - - generator = await handler.create_responses(request, raw_request) + try: + generator = await handler.create_responses(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -618,7 +622,11 @@ async def retrieve_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.retrieve_responses(response_id) + try: + response = await handler.retrieve_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -633,7 +641,11 @@ async def cancel_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.cancel_responses(response_id) + try: + response = await handler.cancel_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), @@ -667,9 +679,11 @@ async def create_chat_completion(request: ChatCompletionRequest, if handler is None: return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - - generator = await handler.create_chat_completion(request, raw_request) - + try: + generator = await handler.create_chat_completion(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -742,7 +756,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") - generator = await handler.create_embedding(request, raw_request) + try: + generator = await handler.create_embedding(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -770,12 +788,15 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Pooling API") - - generator = await handler.create_pooling(request, raw_request) + try: + generator = await handler.create_pooling(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) - elif isinstance(generator, PoolingResponse): + elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): return JSONResponse(content=generator.model_dump()) assert_never(generator) @@ -791,7 +812,11 @@ async def create_classify(request: ClassificationRequest, return base(raw_request).create_error_response( message="The model does not support Classification API") - generator = await handler.create_classify(request, raw_request) + try: + generator = await handler.create_classify(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -820,7 +845,11 @@ async def create_score(request: ScoreRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Score API") - generator = await handler.create_score(request, raw_request) + try: + generator = await handler.create_score(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -878,8 +907,12 @@ async def create_transcriptions(raw_request: Request, message="The model does not support Transcriptions API") audio_data = await request.file.read() - generator = await handler.create_transcription(audio_data, request, - raw_request) + try: + generator = await handler.create_transcription(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -919,8 +952,12 @@ async def create_translations(request: Annotated[TranslationRequest, message="The model does not support Translations API") audio_data = await request.file.read() - generator = await handler.create_translation(audio_data, request, - raw_request) + try: + generator = await handler.create_translation(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -949,7 +986,11 @@ async def do_rerank(request: RerankRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Rerank (Score) API") - generator = await handler.do_rerank(request, raw_request) + try: + generator = await handler.do_rerank(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.error.code) @@ -1044,6 +1085,34 @@ if envs.VLLM_SERVER_DEV_MODE: is_sleeping = await engine_client(raw_request).is_sleeping() return JSONResponse(content={"is_sleeping": is_sleeping}) + @router.post("/collective_rpc") + async def collective_rpc(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}") from e + method = body.get("method") + if method is None: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'method' in request body") + # For security reason, only serialized string args/kwargs are passed. + # User-defined `method` is responsible for deserialization if needed. + args: list[str] = body.get("args", []) + kwargs: dict[str, str] = body.get("kwargs", {}) + timeout: Optional[float] = body.get("timeout") + results = await engine_client(raw_request).collective_rpc( + method=method, timeout=timeout, args=tuple(args), kwargs=kwargs) + if results is None: + return Response(status_code=200) + response: list[Any] = [] + for result in results: + if result is None or isinstance(result, (dict, list)): + response.append(result) + else: + response.append(str(result)) + return JSONResponse(content={"results": response}) + @router.post("/scale_elastic_ep", dependencies=[Depends(validate_json_request)], @@ -1680,6 +1749,8 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_chat = OpenAIServingChat( engine_client, @@ -1697,6 +1768,8 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, @@ -1706,14 +1779,16 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, - model_config, + vllm_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "encode" in supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -1722,23 +1797,22 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None - - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None - + log_error_stack=args.log_error_stack, + ) if ("embed" in supported_tasks or "score" in supported_tasks) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, @@ -1746,18 +1820,21 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) state.openai_serving_transcription = OpenAIServingTranscription( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.openai_serving_translation = OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 6e4eff5c8024377fa2daf511607858c04a39ad8b..d0b5d013eb9e53d668e060f52a5df42c5978525c 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -180,6 +180,8 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT """Maximum number of HTTP headers allowed in a request for h11 parser. Helps mitigate header abuse. Default: 256.""" + log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE + """If set to True, log the stack trace of error responses""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 61f1a09d3ac1c5d8b5700fd43bdda26c800b1e2d..12b274e1211bcba2d20183c7aef69756952125d8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,8 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Literal, Optional, Union +from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional, + TypeVar, Union) import regex as re import torch @@ -38,7 +39,7 @@ from typing_extensions import TypeAlias from vllm import envs from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - random_tool_call_id) + make_tool_call_id) from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam) from vllm.logger import init_logger @@ -357,13 +358,22 @@ class ResponsesRequest(OpenAIBaseModel): temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs, + logprobs=self.top_logprobs + if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), guided_decoding=guided_decoding, ) + def is_include_output_logprobs(self) -> bool: + """Check if the request includes output logprobs.""" + if self.include is None: + return False + return isinstance( + self.include, + list) and "message.output_text.logprobs" in self.include + @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): @@ -443,7 +453,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) @@ -576,6 +586,14 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, description=( @@ -978,7 +996,7 @@ class CompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None allowed_token_ids: Optional[list[int]] = None prompt_logprobs: Optional[int] = None # --8<-- [end:completion-sampling-params] @@ -1062,6 +1080,14 @@ class CompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, @@ -1300,8 +1326,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1368,15 +1396,56 @@ class EmbeddingChatRequest(OpenAIBaseModel): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] PoolingCompletionRequest = EmbeddingCompletionRequest PoolingChatRequest = EmbeddingChatRequest -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + +T = TypeVar("T") + + +class IOProcessorRequest(OpenAIBaseModel, Generic[T]): + model: Optional[str] = None + + priority: int = Field(default=0) + """ + The priority of the request (lower means earlier handling; + default: 0). Any priority other than 0 will raise an error + if the served model does not use priority scheduling. + """ + data: T + """ + When using plugins IOProcessor plugins, the actual input is processed + by the plugin itself. Hence, we use a generic type for the request data + """ + + def to_pooling_params(self): + return PoolingParams(task="encode") + + +class IOProcessorResponse(OpenAIBaseModel, Generic[T]): + + request_id: Optional[str] = None + """ + The request_id associated with this response + """ + created_at: int = Field(default_factory=lambda: int(time.time())) + + data: T + """ + When using plugins IOProcessor plugins, the actual output is generated + by the plugin itself. Hence, we use a generic type for the response data + """ + + +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, + IOProcessorRequest] class ScoreRequest(OpenAIBaseModel): @@ -1405,7 +1474,9 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1435,7 +1506,9 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class RerankDocument(BaseModel): @@ -1480,7 +1553,9 @@ class CompletionResponseChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None # For prompt class CompletionResponse(OpenAIBaseModel): @@ -1511,6 +1586,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + # not part of the OpenAI spec but for tracing the tokens + # prompt tokens is put into choice to align with CompletionResponseChoice + prompt_token_ids: Optional[list[int]] = None + token_ids: Optional[list[int]] = None class CompletionStreamResponse(OpenAIBaseModel): @@ -1587,7 +1666,9 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class ClassificationData(OpenAIBaseModel): @@ -1612,7 +1693,7 @@ class FunctionCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=random_tool_call_id) + id: str = Field(default_factory=make_tool_call_id) type: Literal["function"] = "function" function: FunctionCall @@ -1680,6 +1761,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): finish_reason: Optional[str] = "stop" # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but is useful for tracing the tokens + # in agent scenarios + token_ids: Optional[list[int]] = None class ChatCompletionResponse(OpenAIBaseModel): @@ -1695,6 +1779,7 @@ class ChatCompletionResponse(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1712,6 +1797,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel): logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but for tracing the tokens + token_ids: Optional[list[int]] = None class ChatCompletionStreamResponse(OpenAIBaseModel): @@ -1721,6 +1808,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): model: str choices: list[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + # not part of the OpenAI spec but for tracing the tokens + prompt_token_ids: Optional[list[int]] = None class TranscriptionResponseStreamChoice(OpenAIBaseModel): @@ -1778,7 +1867,7 @@ class ResponsesResponse(OpenAIBaseModel): service_tier: Literal["auto", "default", "flex", "scale", "priority"] status: ResponseStatus text: Optional[ResponseTextConfig] = None - top_logprobs: int + top_logprobs: Optional[int] = None truncation: Literal["auto", "disabled"] usage: Optional[ResponseUsage] = None user: Optional[str] = None @@ -2193,9 +2282,15 @@ class TranscriptionRequest(OpenAIBaseModel): # Transcription response objects +class TranscriptionUsageAudio(OpenAIBaseModel): + type: Literal["duration"] = "duration" + seconds: int + + class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + usage: TranscriptionUsageAudio class TranscriptionWord(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 12349234c320fea432c94d25d991750a01e2e704..6300d0758c3d4e66faf113643bfe45fef95469cc 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -19,7 +19,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, ConversationMessage, - random_tool_call_id) + get_history_tool_calls_cnt, + make_tool_call_id) from vllm.entrypoints.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, get_streamable_parser_for_assistant, get_system_message, parse_chat_input, @@ -75,13 +76,15 @@ class OpenAIServingChat(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack) self.response_role = response_role self.chat_template = chat_template @@ -133,6 +136,10 @@ class OpenAIServingChat(OpenAIServing): source = "model" if source == "auto" else source logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + if self.model_config.hf_config.model_type == 'kimi_k2': + self.tool_call_id_type = 'kimi_k2' + else: + self.tool_call_id_type = 'random' self.use_harmony = model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: @@ -230,7 +237,6 @@ class OpenAIServingChat(OpenAIServing): documents=request.documents, chat_template_kwargs=request.chat_template_kwargs, tool_parser=tool_parser, - truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) else: @@ -379,6 +385,7 @@ class OpenAIServingChat(OpenAIServing): current_text: Optional[str], delta_text: str, function_name_returned: bool, + tool_call_idx: Optional[int] = None ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -424,8 +431,12 @@ class OpenAIServingChat(OpenAIServing): current_tool_call = obj[-2] function_name_returned = True + tool_call_id = make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=current_tool_call["name"], + idx=tool_call_idx) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=random_tool_call_id(), + DeltaToolCall(id=tool_call_id, function=DeltaFunctionCall( name=current_tool_call["name"], arguments=arguments), @@ -491,6 +502,10 @@ class OpenAIServingChat(OpenAIServing): all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 # Always track previous_texts for comprehensive output logging previous_texts = [""] * num_choices @@ -568,12 +583,17 @@ class OpenAIServingChat(OpenAIServing): ), logprobs=None, finish_reason=None) + + # return prompt_token_ids at the first chunk ever chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + prompt_token_ids=(res.prompt_token_ids + if request.return_token_ids else + None)) # if continuous usage stats are requested, add it if include_continuous_usage: @@ -644,9 +664,9 @@ class OpenAIServingChat(OpenAIServing): harmony_parser = harmony_parsers[i] for token_id in output.token_ids: harmony_parser.process(token_id) - # FIXME(woosuk): Support function calling - is_final = harmony_parser.current_channel == "final" - if not (request.include_reasoning or is_final): + is_reasoning = \ + harmony_parser.current_channel == "analysis" + if not request.include_reasoning and is_reasoning: # Skip the reasoning content. continue delta_text = harmony_parser.last_content_delta or "" @@ -668,7 +688,6 @@ class OpenAIServingChat(OpenAIServing): previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text - # avoid the None + list error. if previous_token_ids: current_token_ids = previous_token_ids + as_list( @@ -677,11 +696,11 @@ class OpenAIServingChat(OpenAIServing): current_token_ids = as_list(output.token_ids) if self.use_harmony: - if is_final: - delta_message = DeltaMessage(content=delta_text) - else: + if is_reasoning: delta_message = DeltaMessage( reasoning_content=delta_text) + else: + delta_message = DeltaMessage(content=delta_text) # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] @@ -728,7 +747,7 @@ class OpenAIServingChat(OpenAIServing): index=i) else: delta_tool_call = DeltaToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, @@ -759,7 +778,11 @@ class OpenAIServingChat(OpenAIServing): previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned)) + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt)) + if (delta_message and delta_message.tool_calls and + delta_message.tool_calls[0].id is not None): + history_tool_call_cnt += 1 # update the previous values for the next iteration previous_texts[i] = current_text @@ -865,7 +888,8 @@ class OpenAIServingChat(OpenAIServing): delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if tool_choice_auto or self.reasoning_parser: + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -912,7 +936,9 @@ class OpenAIServingChat(OpenAIServing): index=i, delta=delta_message, logprobs=logprobs, - finish_reason=None) + finish_reason=None, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None)) # if the model is finished generating else: @@ -973,7 +999,9 @@ class OpenAIServingChat(OpenAIServing): logprobs=logprobs, finish_reason=output.finish_reason if not auto_tools_called else "tool_calls", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None)) finish_reason_sent[i] = True @@ -1080,6 +1108,10 @@ class OpenAIServingChat(OpenAIServing): assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 role = self.get_chat_request_role(request) for output in final_res.outputs: @@ -1185,17 +1217,26 @@ class OpenAIServingChat(OpenAIServing): assert content is not None tool_calls = TypeAdapter( list[FunctionDefinition]).validate_json(content) + tool_call_ids = [] + for tool_call in tool_calls: + tool_call_ids.append( + make_tool_call_id(id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt)) + history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", - reasoning_content=reasoning_content, tool_calls=[ - tool_call_class(function=FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, - ensure_ascii=False))) - for tool_call in tool_calls - ]) + tool_call_class(id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, + ensure_ascii=False))) + for i, tool_call in enumerate(tool_calls) + ], + reasoning_content=reasoning_content) # if the request doesn't use tool choice # OR specifies to not use a tool @@ -1239,7 +1280,6 @@ class OpenAIServingChat(OpenAIServing): if (tool_call_info.content and len(tool_call_info.content) > 0): ret_content = tool_call_info.content - message = ChatMessage(role=role, reasoning_content=reasoning_content, content=ret_content) @@ -1260,7 +1300,10 @@ class OpenAIServingChat(OpenAIServing): logprobs=logprobs, finish_reason="tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), + ) choices.append(choice_data) @@ -1301,6 +1344,8 @@ class OpenAIServingChat(OpenAIServing): choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + prompt_token_ids=(final_res.prompt_token_ids + if request.return_token_ids else None), kv_transfer_params=final_res.kv_transfer_params, ) @@ -1313,12 +1358,11 @@ class OpenAIServingChat(OpenAIServing): elif choice.message.tool_calls: # For tool calls, log the function name and arguments tool_call_descriptions = [] - for tool_call in choice.message.tool_calls: - if hasattr(tool_call.function, "name") and hasattr( - tool_call.function, "arguments"): + for tc in choice.message.tool_calls: + if hasattr(tc.function, "name") and hasattr( + tc.function, "arguments"): tool_call_descriptions.append( - f"{tool_call.function.name}({tool_call.function.arguments})" - ) + f"{tc.function.name}({tc.function.arguments})") tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" @@ -1469,4 +1513,9 @@ class OpenAIServingChat(OpenAIServing): # Render prompt token ids. prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f6847179fe8db27e49681fdc4514cec203b..b4fdc36390319dc4eb0573fe7ccd0c711123e814 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -61,7 +61,6 @@ class ClassificationMixin(OpenAIServing): ctx.request, ctx.tokenizer, ctx.request.input, - truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, ) return None @@ -129,12 +128,14 @@ class ServingClassification(ClassificationMixin): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, + log_error_stack=log_error_stack, ) async def create_classify( @@ -155,18 +156,6 @@ class ServingClassification(ClassificationMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ClassificationServeContext, - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 22c6b6250394c0775652424747748a572c84ca99..11effba8f9eb37ee807198e7e8eaa805a44c7a2e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -42,7 +42,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import merge_async_iterators +from vllm.utils import as_list, merge_async_iterators logger = init_logger(__name__) @@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServing): return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, @@ -67,6 +68,7 @@ class OpenAIServingCompletion(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = ( @@ -125,13 +127,16 @@ class OpenAIServingCompletion(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = await self.engine_client.get_tokenizer(lora_request + ) request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, request.prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) except ValueError as e: @@ -365,6 +370,11 @@ class OpenAIServingCompletion(OpenAIServing): for output in res.outputs: i = output.index + prompt_idx * num_choices + # Useful when request.return_token_ids is True + # Returning prompt token IDs shares the same logic + # with the echo implementation. + prompt_token_ids_to_return: Optional[list[int]] = None + assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None @@ -385,6 +395,7 @@ class OpenAIServingCompletion(OpenAIServing): *(prompt_logprobs or []), *(output.logprobs or []), ] + prompt_token_ids_to_return = prompt_token_ids has_echoed[i] = True else: # return just the delta @@ -392,6 +403,12 @@ class OpenAIServingCompletion(OpenAIServing): delta_token_ids = output.token_ids out_logprobs = output.logprobs + # has_echoed[i] is reused here to indicate whether + # we have already returned the prompt token IDs. + if not has_echoed[i]: + prompt_token_ids_to_return = prompt_token_ids + has_echoed[i] = True + if (not delta_text and not delta_token_ids and not previous_num_tokens[i]): # Chunked prefill case, don't return empty chunks @@ -428,6 +445,9 @@ class OpenAIServingCompletion(OpenAIServing): logprobs=logprobs, finish_reason=finish_reason, stop_reason=stop_reason, + prompt_token_ids=prompt_token_ids_to_return, + token_ids=(as_list(output.token_ids) if + request.return_token_ids else None), ) ], ) @@ -548,6 +568,10 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason=output.finish_reason, stop_reason=output.stop_reason, prompt_logprobs=final_res.prompt_logprobs, + prompt_token_ids=(prompt_token_ids + if request.return_token_ids else None), + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), ) choices.append(choice_data) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9dcad8e391c68adb5cc893924812a7278ff114e2..0a0d98db2d0d89c888a9d4fdb469cd0f5c0030e2 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -97,7 +97,6 @@ class EmbeddingMixin(OpenAIServing): # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, add_special_tokens=ctx.request.add_special_tokens, ) else: @@ -106,7 +105,6 @@ class EmbeddingMixin(OpenAIServing): ctx.request, tokenizer, ctx.request.input, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, add_special_tokens=ctx.request.add_special_tokens, ) return None @@ -593,11 +591,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format @@ -629,18 +629,6 @@ class OpenAIServingEmbedding(EmbeddingMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ServeContext[EmbeddingRequest], - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 29a6f76805dce1be9b8b9968dc3c774991f3c577..a20c4eda2c82fef2223991b3c1d12a86af6ca3ec 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,6 +5,7 @@ import io import json import sys import time +import traceback from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus @@ -48,9 +49,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, ErrorInfo, - ErrorResponse, PoolingResponse, - RerankRequest, ResponsesRequest, - ScoreRequest, ScoreResponse, + ErrorResponse, + IOProcessorRequest, + PoolingResponse, RerankRequest, + ResponsesRequest, ScoreRequest, + ScoreResponse, TokenizeChatRequest, TokenizeCompletionRequest, TokenizeResponse, @@ -66,7 +69,7 @@ from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict) + MultiModalDataDict, MultiModalUUIDDict) from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -90,7 +93,7 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, - ResponsesRequest] + ResponsesRequest, IOProcessorRequest] AnyResponse = Union[ CompletionResponse, @@ -166,7 +169,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, # Shared across most requests tokenizer: Optional[AnyTokenizer] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # `protected_namespaces` resolves Pydantic v2's warning # on conflict with protected namespace "model_" @@ -207,6 +209,7 @@ class OpenAIServing: request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__() @@ -228,6 +231,7 @@ class OpenAIServing: self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} + self.log_error_stack = log_error_stack def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ @@ -300,14 +304,12 @@ class OpenAIServing: truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - if truncate_prompt_tokens is not None: - if truncate_prompt_tokens <= self.max_model_len: - ctx.truncate_prompt_tokens = truncate_prompt_tokens - else: - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") return None def _create_pooling_params( @@ -418,6 +420,12 @@ class OpenAIServing: message: str, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() return ErrorResponse(error=ErrorInfo( message=message, type=err_type, code=status_code.value)) @@ -523,9 +531,8 @@ class OpenAIServing: async def _normalize_prompt_text_to_input( self, request: AnyRequest, - tokenizer: AnyTokenizer, prompt: str, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], + tokenizer: AnyTokenizer, add_special_tokens: bool, ) -> TextTokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) @@ -535,6 +542,9 @@ class OpenAIServing: "do_lower_case", False)): prompt = prompt.lower() + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + if truncate_prompt_tokens is None: encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens) @@ -564,11 +574,11 @@ class OpenAIServing: async def _normalize_prompt_tokens_to_input( self, request: AnyRequest, - tokenizer: AnyTokenizer, prompt_ids: list[int], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + tokenizer: Optional[AnyTokenizer], ) -> TextTokensPrompt: - async_tokenizer = self._get_async_tokenizer(tokenizer) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) if truncate_prompt_tokens is None: input_ids = prompt_ids @@ -577,7 +587,11 @@ class OpenAIServing: else: input_ids = prompt_ids[-truncate_prompt_tokens:] - input_text = await async_tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else await self.tokenizer.decode_all(input_ids) + if tokenizer is None: + input_text = "" + else: + async_tokenizer = self._get_async_tokenizer(tokenizer) + input_text = await async_tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else await self.tokenizer.decode_all(input_ids) return self._validate_input(request, input_ids, input_text) @@ -651,7 +665,6 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_input: Union[str, list[int]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: """ @@ -663,7 +676,6 @@ class OpenAIServing: request, tokenizer, [prompt_input], - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ): return result @@ -674,7 +686,6 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> AsyncGenerator[TextTokensPrompt, None]: """ @@ -682,30 +693,27 @@ class OpenAIServing: [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] that assumes multiple inputs. """ - for text in prompt_inputs: - if isinstance(text, str): + for prompt in prompt_inputs: + if isinstance(prompt, str): yield await self._normalize_prompt_text_to_input( request, - tokenizer, - prompt=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt=prompt, + tokenizer=tokenizer, add_special_tokens=add_special_tokens, ) else: yield await self._normalize_prompt_tokens_to_input( request, - tokenizer, - prompt_ids=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt_ids=prompt, + tokenizer=tokenizer, ) async def _tokenize_prompt_input_or_inputs_async( self, request: AnyRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: """ @@ -718,6 +726,12 @@ class OpenAIServing: inputs_embeds = list[EmbedsPrompt]() inputs_text = list[TextTokensPrompt]() + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + + if (truncate_prompt_tokens or 0) < 0: + truncate_prompt_tokens = self.max_model_len + if (isinstance(request, CompletionRequest) and request.prompt_embeds is not None): inputs_embeds.extend( @@ -741,18 +755,16 @@ class OpenAIServing: tasks = [] for prompt_input in batch_inputs: if prompt_input["is_tokens"] is False: + assert tokenizer is not None, \ + "Tokenizer is required for text prompts" task = self._normalize_prompt_text_to_input( request, - tokenizer, prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, + tokenizer=tokenizer, add_special_tokens=add_special_tokens) else: task = self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) + request, prompt_input["content"], tokenizer=tokenizer) tasks.append(task) # Wait for all tokenization tasks to complete @@ -767,9 +779,8 @@ class OpenAIServing: request: Union[DetokenizeRequest, EmbeddingCompletionRequest, RerankRequest, ClassificationRequest, ScoreRequest, TokenizeCompletionRequest], - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., add_special_tokens: bool = ..., ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: ... @@ -778,10 +789,9 @@ class OpenAIServing: async def _preprocess_completion( self, request: CompletionRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., add_special_tokens: bool = ..., ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ EngineTokensPrompt, EngineEmbedsPrompt]]]: @@ -790,10 +800,9 @@ class OpenAIServing: async def _preprocess_completion( self, request: CompletionLikeRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> tuple[Union[list[TextTokensPrompt], list[Union[ TextTokensPrompt, EmbedsPrompt]]], Union[ @@ -810,7 +819,6 @@ class OpenAIServing: request, tokenizer, input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) @@ -863,7 +871,6 @@ class OpenAIServing: documents: Optional[list[dict[str, str]]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], list[EngineTokensPrompt]]: @@ -938,7 +945,6 @@ class OpenAIServing: request, tokenizer, request_prompt, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) else: diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 38745d001ade689bac6406e87c4d81b36d25a466..685c98c817c3dec4bc536661b611a651919992f3 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -4,7 +4,7 @@ import asyncio import base64 import time -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from typing import Final, Literal, Optional, Union, cast import jinja2 @@ -13,19 +13,25 @@ import torch from fastapi import Request from typing_extensions import assert_never -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf: disable from vllm.entrypoints.openai.protocol import (ErrorResponse, + IOProcessorRequest, + IOProcessorResponse, PoolingChatRequest, + PoolingCompletionRequest, PoolingRequest, PoolingResponse, PoolingResponseData, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.utils import merge_async_iterators logger = init_logger(__name__) @@ -52,26 +58,30 @@ class OpenAIServingPooling(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, + vllm_config: VllmConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, - model_config=model_config, + model_config=vllm_config.model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + io_processor_plugin = self.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) async def create_pooling( self, request: PoolingRequest, raw_request: Optional[Request] = None, - ) -> Union[PoolingResponse, ErrorResponse]: + ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]: """ See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. @@ -80,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is not None: return error_check_ret - encoding_format = request.encoding_format - if request.dimensions is not None: - return self.create_error_response( - "dimensions is currently not supported") - model_name = self._get_model_name(request.model) + request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - truncate_prompt_tokens = request.truncate_prompt_tokens - + is_io_processor_request = isinstance(request, IOProcessorRequest) try: - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) lora_request = self._maybe_get_adapters(request) if self.model_config.skip_tokenizer_init: @@ -102,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer(lora_request ) - if isinstance(request, PoolingChatRequest): + if getattr(request, "dimensions", None) is not None: + return self.create_error_response( + "dimensions is currently not supported") + + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens) + + if is_io_processor_request: + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details.") + + validated_prompt = self.io_processor.parse_request(request) + + engine_prompts = await self.io_processor.pre_process_async( + prompt=validated_prompt, request_id=request_id) + request_prompts: Sequence[RequestPrompt] = [ + "" + ] * len(engine_prompts) + + elif isinstance(request, PoolingChatRequest): ( _, request_prompts, @@ -118,18 +146,19 @@ class OpenAIServingPooling(OpenAIServing): # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) - else: + elif isinstance(request, PoolingCompletionRequest): (request_prompts, engine_prompts) = await self._preprocess_completion( request, tokenizer, request.input, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) + else: + raise ValueError( + f"Unsupported request of type {type(request)}") except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing): result_generator = merge_async_iterators(*generators) + if is_io_processor_request: + assert self.io_processor is not None + output = await self.io_processor.post_process_async( + model_output=result_generator, + request_id=request_id, + ) + return self.io_processor.output_to_response(output) + + assert isinstance(request, + (PoolingCompletionRequest, PoolingChatRequest)) num_prompts = len(engine_prompts) # Non-streaming response @@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing): request_id, created_time, model_name, - encoding_format, + request.encoding_format, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 86c16df40e693afd25271bcc22b7d69e4519a566..899cb07b2b37d3fa5834bd418713d69129eabde9 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -4,11 +4,11 @@ import asyncio import json import time -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterator, Sequence from contextlib import AsyncExitStack from copy import copy from http import HTTPStatus -from typing import Any, Callable, Final, Optional, Union +from typing import Callable, Final, Optional, Union import jinja2 import openai.types.responses as openai_responses_types @@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent, ResponseReasoningItem, ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent) +from openai.types.responses.response_output_text import (Logprob, + LogprobTopLogprob) # yapf: enable from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent) @@ -59,6 +61,8 @@ from vllm.logger import init_logger from vllm.outputs import CompletionOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob as SampleLogprob +from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -84,6 +88,7 @@ class OpenAIServingResponses(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -92,6 +97,7 @@ class OpenAIServingResponses(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ -201,6 +207,12 @@ class OpenAIServingResponses(OpenAIServing): # (i.e., their request's `store=True` just because it's the default # value). request.store = False + if self.use_harmony and request.is_include_output_logprobs(): + return self.create_error_response( + err_type="invalid_request_error", + message="logprobs are not supported with gpt-oss models", + status_code=HTTPStatus.BAD_REQUEST, + ) # Handle the previous response ID. prev_response_id = request.previous_response_id @@ -238,10 +250,10 @@ class OpenAIServingResponses(OpenAIServing): raw_request.state.request_metadata = request_metadata if self.tool_server is not None and isinstance( - self.tool_server, MCPToolServer - ) and (request.background or request.stream) and request.tools and any( - tool.type in ["web_search_preview", "code_interpreter"] - for tool in request.tools): + self.tool_server, + MCPToolServer) and request.stream and request.tools and any( + tool.type in ["web_search_preview", "code_interpreter"] + for tool in request.tools): return self.create_error_response( "MCP tool server is not supported in background mode and " "streaming mode") @@ -255,114 +267,70 @@ class OpenAIServingResponses(OpenAIServing): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): builtin_tool_list.append("python") - async with AsyncExitStack() as exit_stack: - try: - if self.tool_server is not None: - # TODO: initialize tool sessions lazily when the session - # is actually used. - tool_session_ctxs: dict[str, Any] = { - tool_name: - exit_stack.enter_async_context( - self.tool_server.new_session(tool_name)) - for tool_name in builtin_tool_list - } - tool_sessions = {} - for tool_name in builtin_tool_list: - tool_sessions[tool_name] = ( - await tool_session_ctxs[tool_name]) - else: - assert len(builtin_tool_list) == 0 - tool_sessions = {} - for i, engine_prompt in enumerate(engine_prompts): - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers( - raw_request.headers)) - - context: ConversationContext - if self.use_harmony: - if request.stream: - context = StreamingHarmonyContext( - messages, tool_sessions) - else: - context = HarmonyContext(messages, tool_sessions) - else: - context = SimpleContext() - generator = self._generate_with_builtin_tools( - request_id=request.request_id, - request_prompt=request_prompts[i], - engine_prompt=engine_prompt, - sampling_params=sampling_params, - context=context, - lora_request=lora_request, - priority=request.priority, - trace_headers=trace_headers, - ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - assert len(generators) == 1 - result_generator, = generators - - # Store the input messages. - if request.store: - self.msg_store[request.request_id] = messages - - if request.background: - created_time = int(time.time()) - response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="queued", - usage=None, + if self.tool_server is not None: + available_tools = builtin_tool_list + else: + assert len(builtin_tool_list) == 0 + available_tools = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext( + messages, available_tools) + else: + context = HarmonyContext(messages, available_tools) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + request_prompt=request_prompts[i], + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, + lora_request=lora_request, + priority=request.priority, + trace_headers=trace_headers, ) - async with self.response_store_lock: - self.response_store[response.id] = response + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) - # Run the request in the background. - task = asyncio.create_task( - self._run_background_request( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - created_time, - ), - name=f"create_{response.id}", - ) + assert len(generators) == 1 + result_generator, = generators - # For cleanup. - response_id = response.id - self.background_tasks[response_id] = task - task.add_done_callback( - lambda _: self.background_tasks.pop(response_id, None)) - return response + # Store the input messages. + if request.store: + self.msg_store[request.request_id] = messages - if request.stream: - return self.responses_stream_generator( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - ) + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response - try: - return await self.responses_full_generator( + # Run the request in the background. + task = asyncio.create_task( + self._run_background_request( request, sampling_params, result_generator, @@ -370,10 +338,41 @@ class OpenAIServingResponses(OpenAIServing): model_name, tokenizer, request_metadata, - ) - except Exception as e: - return self.create_error_response(str(e)) - return self.create_error_response("Should not reach here") + created_time, + ), + name=f"create_{response.id}", + ) + + # For cleanup. + response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None)) + return response + + if request.stream: + return self.responses_stream_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + + try: + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + except Exception as e: + return self.create_error_response(str(e)) async def _make_request( self, @@ -408,6 +407,11 @@ class OpenAIServingResponses(OpenAIServing): request, prev_response) prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] async def responses_full_generator( @@ -424,14 +428,16 @@ class OpenAIServingResponses(OpenAIServing): if created_time is None: created_time = int(time.time()) - try: - async for _ in result_generator: - pass - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + async with AsyncExitStack() as exit_stack: + try: + await context.init_tool_sessions(self.tool_server, exit_stack) + async for _ in result_generator: + pass + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) if self.use_harmony: assert isinstance(context, HarmonyContext) @@ -486,6 +492,51 @@ class OpenAIServingResponses(OpenAIServing): self.response_store[response.id] = response return response + def _topk_logprobs(self, logprobs: dict[int, + SampleLogprob], top_logprobs: int, + tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]: + """Returns the top-k logprobs from the logprobs dictionary.""" + out = [] + for i, (token_id, _logprob) in enumerate(logprobs.items()): + if i >= top_logprobs: + break + text = _logprob.decoded_token if _logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + LogprobTopLogprob( + token=text, + logprob=max(_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + )) + return out + + def _create_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None) -> list[Logprob]: + assert logprobs is not None, "logprobs must be provided" + assert len(token_ids) == len(logprobs), ( + "token_ids and logprobs.token_ids must have the same length") + out = [] + for i, token_id in enumerate(token_ids): + logprob = logprobs[i] + token_logprob = logprob[token_id] + text = token_logprob.decoded_token if token_logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + Logprob( + token=text, + logprob=max(token_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + top_logprobs=self._topk_logprobs(logprob, + top_logprobs=top_logprobs, + tokenizer=tokenizer) + if top_logprobs else [], + )) + return out + def _make_response_output_items( self, request: ResponsesRequest, @@ -542,7 +593,12 @@ class OpenAIServingResponses(OpenAIServing): text=content, annotations=[], # TODO type="output_text", - logprobs=None, # TODO + logprobs=self._create_response_logprobs( + token_ids=final_output.token_ids, + logprobs=final_output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) if request.is_include_output_logprobs() else None, ) message = ResponseOutputMessage( id=f"msg_{random_uuid()}", @@ -773,7 +829,7 @@ class OpenAIServingResponses(OpenAIServing): status_code=HTTPStatus.BAD_REQUEST, ) - async def responses_stream_generator( + async def _process_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, @@ -782,18 +838,8 @@ class OpenAIServingResponses(OpenAIServing): model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - created_time: Optional[int] = None, + created_time: int, ) -> AsyncGenerator[str, None]: - # TODO: - # 1. Handle disconnect - - if not isinstance(context, StreamingHarmonyContext): - raise NotImplementedError( - "Streaming is not supported for responses API without Harmony." - ) - - created_time = created_time or int(time.time()) - sequence_number = 0 def _send_event(event: BaseModel): @@ -1004,7 +1050,48 @@ class OpenAIServingResponses(OpenAIServing): delta=ctx.parser.last_content_delta, sequence_number=-1, )) - + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif (ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type= + "response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + )) if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if (self.tool_server is not None @@ -1100,30 +1187,6 @@ class OpenAIServingResponses(OpenAIServing): and self.tool_server.has_tool("python") and previous_item.recipient is not None and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code="", - container_id="auto", - outputs=[], - status="in_progress", - ), - )) - yield _send_event( - openai_responses_types. - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - )) - # TODO: do we need to add delta event here? yield _send_event( openai_responses_types. ResponseCodeInterpreterCallCodeDoneEvent( @@ -1131,7 +1194,8 @@ class OpenAIServingResponses(OpenAIServing): sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - code=previous_item.content[0].text)) + code=previous_item.content[0].text, + )) yield _send_event( openai_responses_types. ResponseCodeInterpreterCallInterpretingEvent( @@ -1187,3 +1251,31 @@ class OpenAIServingResponses(OpenAIServing): sequence_number=-1, response=final_response.model_dump(), )) + + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[str, None]: + # TODO: + # 1. Handle disconnect + + if not isinstance(context, StreamingHarmonyContext): + raise NotImplementedError( + "Streaming is not supported for responses API without Harmony." + ) + + created_time = created_time or int(time.time()) + + async with AsyncExitStack() as exit_stack: + await context.init_tool_sessions(self.tool_server, exit_stack) + async for event_data in self._process_streaming_events( + request, sampling_params, result_generator, context, + model_name, tokenizer, request_metadata, created_time): + yield event_data diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index c246274514dbfdbd7c15788b4ca27303725506b0..847c014a11dc3bbc4d2e1c3a1d0449cc043cc324 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -7,7 +7,6 @@ from typing import Any, Optional, Union from fastapi import Request -from vllm import envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger @@ -47,11 +46,13 @@ class ServingScores(OpenAIServing): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) async def _embedding_score( self, @@ -227,8 +228,7 @@ class ServingScores(OpenAIServing): params=default_pooling_params, lora_request=lora_request) - if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( - "token_type_ids", None)): + if (token_type_ids := engine_prompt.pop("token_type_ids", None)): pooling_params = default_pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) pooling_params.extra_kwargs = { @@ -266,12 +266,14 @@ class ServingScores(OpenAIServing): request: Union[ScoreRequest, RerankRequest], request_id: str, raw_request: Optional[Request] = None, - truncate_prompt_tokens: Optional[int] = None, ) -> Union[list[PoolingRequestOutput], ErrorResponse]: lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -343,7 +345,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch @@ -391,7 +392,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 58d720474768b0069dad5b4ed5b4fe8f66dbf662..2f258255d5f16ee6b6bac51058f019ba10724448 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -39,11 +39,13 @@ class OpenAIServingTokenization(OpenAIServing): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0d6989fe91bfa092ccf325c17e53adb72431822b..9ba58d4425221979315af73b1c1f16368d3b453f 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -32,13 +32,15 @@ class OpenAIServingTranscription(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe") + task_type="transcribe", + log_error_stack=log_error_stack) async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, @@ -88,13 +90,15 @@ class OpenAIServingTranslation(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate") + task_type="translate", + log_error_stack=log_error_stack) async def create_translation( self, audio_data: bytes, request: TranslationRequest, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 01140a4bfea7edcbc8bad5aef077952978b7f922..1cbd7dba393f6c93bc851e8686a37489927c3312 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -53,12 +53,14 @@ class OpenAISpeechToText(OpenAIServing): request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) @@ -200,7 +202,22 @@ class OpenAISpeechToText(OpenAIServing): for result_generator in list_result_generator: async for op in result_generator: text += op.outputs[0].text - return cast(T, response_class(text=text)) + + if self.task_type == "transcribe": + # add usage in TranscriptionResponse. + usage = { + "type": "duration", + # rounded up as per openAI specs + "seconds": int(math.ceil(duration_s)), + } + final_response = cast(T, response_class(text=text, + usage=usage)) + else: + # no usage in response for translation task + final_response = cast( + T, response_class(text=text)) # type: ignore[call-arg] + + return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 099e456aa486f3960eceec56961236b14f2792bd..44aa1208a54c783eeb99b4aabb3ea6e4e27d0c4d 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,6 +3,7 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .deepseekv31_tool_parser import DeepSeekV31ToolParser from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser @@ -18,6 +19,7 @@ from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -35,11 +37,13 @@ __all__ = [ "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser", + "DeepSeekV31ToolParser", "xLAMToolParser", "MinimaxToolParser", "KimiK2ToolParser", "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "SeedOssToolParser", "Step3ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9188190f3f0cd0434afac5de7c7f5df1580205 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("deepseek_v31") +class DeepSeekV31ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" + self.tool_calls_end_token: str = "<|tool▁calls▁end|>" + + self.tool_call_start_token: str = "<|tool▁call▁begin|>" + self.tool_call_end_token: str = "<|tool▁call▁end|>" + + self.tool_call_regex = re.compile( + r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)") + + self.stream_tool_call_name_regex = re.compile( + r"(?P<function_name>.*)<|tool▁sep|>") + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3.1 Tool parser could not locate tool call " + "start/end tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + function_name, function_args = match + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_name, tool_args = current_tool_call_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_name = current_tool_call_name_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index da4760ad1b6428b655decec528579602ecf7ae4e..ac272b0c3b2054aea0a89c35c5d332de9829b952 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -6,7 +6,7 @@ from typing import Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser): DeltaToolCall( index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True), diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 5508ba6a3940841e5661fba34c9df065c105a0db..824b100f357b5bd616db11ddf3b508d4a8c9fd46 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -10,7 +10,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index fcc5b7edda83f1ffdb58d63e7a63935abbbda987..ac517616a95b4852c1a75b29464df134bad466ca 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index d126130ab9bc387a5d56014a87bf1d8ff6eeb844..a6ce33af6bd0078cc065b8fd4f416ecfee555a49 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 92004de030d14b2cc0f7f1fd38cff9bb7a1a38a2..6ef8fadf59ac5856914355a6a5c3703a4135976e 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 66b483d8b0f66084bd9172365f00bcebc7eaf660..3b41f6034704cfb15eafba6cfcf5c2f7f97e2444 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -222,7 +222,7 @@ class JambaToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 194a144ad576e326ae492e2bc672231ed06777be..31b19c8db4163410dc363dda70b98d5da84ce471 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -10,7 +10,7 @@ import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 226309ef293a9e28403b379e0daab7acea699810..0fd62f0b6a7f134dec3e431f8d6a266bcaf1c42d 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser): sent_tools.append({ "sent_name": False, "sent_arguments": "", - "id": random_tool_call_id(), + "id": make_tool_call_id(), }) while len(tool_ids) < tool_count: @@ -461,7 +461,8 @@ class MinimaxToolParser(ToolParser): i += 1 return boundaries - def _extract_tool_args(self, tool_content: str, args_match) -> str: + def _extract_tool_args(self, tool_content: str, + args_match: re.Match[str]) -> str: """ Extract tool arguments from tool content. diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 5501028cf36b8a1d37d82ecc92c2314409103af4..85dd56213c6ac505a6c6004b7dc1d4ee0e9f0e64 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -8,7 +8,7 @@ from typing import Any, Optional import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser): tool_calls: list[ToolCall] = [ ToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index 2501d6739e8f6f38e69d3a2f266eb31e5bcf618c..955813ddd340861884a20306ce34df69fe838c7f 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import ast import json import uuid from collections.abc import Sequence @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] + # Override base class type - we use string IDs for tool calls + self.current_tool_id: Optional[str] = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser): self.is_tool_call_started: bool = False self.failed_count: int = 0 - # Streaming state variables - self.current_tool_index: int = 0 - self.header_sent: bool = False - self.current_tool_string_id: Optional[str] = None - self.current_function_name: Optional[str] = None - self.current_param_name: Optional[str] = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.accumulated_text: str = "" - self.json_started: bool = False - self.json_closed: bool = False - # Enhanced streaming state - reset for each new message self._reset_streaming_state() @@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser): self.tool_call_function_regex = re.compile( r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)", + re.DOTALL) if not self.model_tokenizer: raise ValueError( @@ -84,8 +73,8 @@ class Qwen3CoderToolParser(ToolParser): "Qwen3 XML Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - logger.debug("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_index = 0 self.is_tool_call_started = False self.header_sent = False - self.current_tool_string_id = None + self.current_tool_id = None self.current_function_name = None self.current_param_name = None self.current_param_value = "" @@ -106,122 +95,122 @@ class Qwen3CoderToolParser(ToolParser): self.accumulated_text = "" self.json_started = False self.json_closed = False - - def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] - ) -> Optional[ToolCall]: - - def get_arguments_config(func_name: str) -> dict: - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): - continue - if (config.type == "function" - and config.function.name == func_name): - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + + def _get_arguments_config( + self, func_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> dict: + """Extract argument configuration for a function.""" + if tools is None: return {} + for config in tools: + if not hasattr(config, "type") or not (hasattr( + config, "function") and hasattr(config.function, "name")): + continue + if config.type == "function" and config.function.name == func_name: + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def _convert_param_value(self, param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - # Handle null value for any type - if param_value.lower() == "null": - return None - - converted_value: Any - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if isinstance(param_config[param_name], + dict) and "type" in param_config[param_name]: + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif param_type.startswith("int") or param_type.startswith( + "uint") or param_type.startswith( + "long") or param_type.startswith( + "short") or param_type.startswith("unsigned"): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): - try: - converted_value = int(param_value) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not an " - "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value = float(param_value) + return float_param_value if float_param_value - int( + float_param_value) != 0 else int(float_param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) return param_value - elif (param_type.startswith("num") - or param_type.startswith("float")): + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` or `false`) in tool '%s', degenerating to " + "false.", param_value, param_name, func_name) + return param_value == "true" + else: + if param_type in ["object", "array", "arr" + ] or param_type.startswith( + "dict") or param_type.startswith("list"): try: - float_param_value = float(param_value) - converted_value = (float_param_value if float_param_value - - int(float_param_value) != 0 else - int(float_param_value)) - return converted_value - except ValueError: + param_value = json.loads(param_value) + return param_value + except (json.JSONDecodeError, TypeError, ValueError): logger.warning( - "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) - return param_value - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if param_value not in ["true", "false"]: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "boolean (`true` of `false`) in tool '%s', " - "degenerating to false.", param_value, param_name, + "Parsed value '%s' of parameter '%s' cannot be " + "parsed with json.loads in tool '%s', will try " + "other methods to parse it.", param_value, param_name, func_name) - return param_value == "true" - else: - if param_type == "object" or param_type.startswith("dict"): - try: - converted_value = json.loads(param_value) - return converted_value - except json.JSONDecodeError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "valid JSON object in tool '%s', will try other " - "methods to parse it.", param_value, param_name, - func_name) + try: + param_value = ast.literal_eval(param_value) # safer + except (ValueError, SyntaxError, TypeError): logger.warning( - "Parameter '%s' has unknown type '%s'. " - "The value will be treated as a string.", param_name, - param_type) - return param_value + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `ast.literal_eval()` in tool " + "'%s', degenerating to string.", param_value, param_name, + func_name) + return param_value + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] - param_config = get_arguments_config(function_name) + param_config = self._get_arguments_config(function_name, tools) parameters = function_call_str[end_index + 1:] param_dict = {} - for match in self.tool_call_parameter_regex.findall(parameters): - match_text = match[0] if match[0] else match[1] + for match_text in self.tool_call_parameter_regex.findall(parameters): idx = match_text.index(">") param_name = match_text[:idx] param_value = str(match_text[idx + 1:]) @@ -231,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser): if param_value.endswith("\n"): param_value = param_value[:-1] - param_dict[param_name] = convert_param_value( + param_dict[param_name] = self._convert_param_value( param_value, param_name, param_config, function_name) return ToolCall( type="function", @@ -284,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser): for function_call_str in function_calls ] - # Populate prev_tool_call_arr for serving layer to set - # finish_reason + # Populate prev_tool_call_arr for serving layer to set finish_reason self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: @@ -298,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser): # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) - content_index = (content_index if content_index >= 0 else - model_output.find(self.tool_call_prefix)) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx content = model_output[:content_index] # .rstrip() return ExtractedToolCallInformation( @@ -324,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # If no delta text, return None unless it's an EOS token after tool - # calls + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools if not delta_text: # Check if this is an EOS token after all tool calls are complete - # We check for tool calls in the text even if is_tool_call_started - # is False because it might have been reset after processing all - # tools + # Check for tool calls in text even if is_tool_call_started + # is False (might have been reset after processing all tools) if (delta_token_ids and self.tool_call_end_token_id not in delta_token_ids): # Count complete tool calls @@ -339,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser): # If we have completed tool calls and populated # prev_tool_call_arr - if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed - open_calls = ( - current_text.count(self.tool_call_start_token) - - current_text.count(self.tool_call_end_token)) + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) if open_calls == 0: - # Return empty delta message to allow finish_reason - # processing + # Return empty delta for finish_reason processing return DeltaMessage(content="") elif not self.is_tool_call_started and current_text: # This is a regular content response that's now complete return DeltaMessage(content="") return None - # Check if this is the first call (reset state if needed) - if not previous_text: - self._reset_streaming_state() - # Update accumulated text self.accumulated_text = current_text @@ -371,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser): self.param_count = 0 self.json_started = False self.json_closed = False + self.accumulated_params = {} # Check if there are more tool calls - tool_starts_count = current_text.count( - self.tool_call_start_token) - if self.current_tool_index >= tool_starts_count: + tool_starts = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts: # No more tool calls self.is_tool_call_started = False # Continue processing next tool @@ -412,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index - tool_starts: list[int] = [] + tool_start_positions: list[int] = [] idx = 0 while True: idx = current_text.find(self.tool_call_start_token, idx) if idx == -1: break - tool_starts.append(idx) + tool_start_positions.append(idx) idx += len(self.tool_call_start_token) - if self.current_tool_index >= len(tool_starts): + if self.current_tool_index >= len(tool_start_positions): # No more tool calls to process yet return None - tool_start_idx = tool_starts[self.current_tool_index] + tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) @@ -438,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser): # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_string_id = self._generate_tool_call_id() + self.current_tool_id = self._generate_tool_call_id() self.header_sent = True self.in_function = True - # IMPORTANT: Add to prev_tool_call_arr immediately when we - # detect a tool call. This ensures + # IMPORTANT: Add to prev_tool_call_arr immediately when + # we detect a tool call. This ensures # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name @@ -466,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, - id=self.current_tool_string_id, + id=self.current_tool_id, function=DeltaFunctionCall( name=self.current_function_name, arguments=""), type="function", @@ -496,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser): # Close JSON self.json_closed = True - # Extract the complete tool call to update prev_tool_call_arr - # with final arguments. Find the function content - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + # Extract complete tool call to update + # prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: @@ -507,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser): # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, self.streaming_request.tools + if self.streaming_request else None) if parsed_tool: - # Update existing entry in prev_tool_call_arr with - # complete arguments + # Update existing entry in + # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if (tool.get("name") == - parsed_tool.function.name): - self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + if tool.get( + "name") == parsed_tool.function.name: + args = parsed_tool.function.arguments + self.prev_tool_call_arr[i][ + "arguments"] = args break except Exception: pass # Ignore parsing errors during streaming @@ -530,73 +519,110 @@ class Qwen3CoderToolParser(ToolParser): # Reset state for next tool self.in_function = False self.json_closed = True + self.accumulated_params = {} return result # Look for parameters - # Count how many complete parameters we have processed - complete_params = tool_text.count(self.parameter_end_token) + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) # Check if we should start a new parameter - if not self.in_param and self.param_count < complete_params: - # Find the unprocessed parameter - # Count parameter starts - param_starts = [] - idx = 0 - while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) - - if len(param_starts) > self.param_count: - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] - - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - self.current_param_name = remaining[:name_end] - - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] - - # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Build complete JSON fragment for this parameter - if self.param_count == 0: - json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + if (not self.in_param and self.param_count < len(param_starts) + and len(param_starts) > self.param_count): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or + # function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.function_end_token) + + if next_param_idx != -1 and (func_end_idx == -1 + or next_param_idx + < func_end_idx): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.tool_call_end_token in tool_text: + # Tool call is complete, so parameter + # must be complete too. Use all + # remaining text before function end + param_end_idx = len(value_text) else: - json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') - - self.param_count += 1 - - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) - - # Continue parameter value + # Still streaming, wait for more content + return None + + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Store raw value for later processing + self.accumulated_params[ + self.current_param_name] = param_value + + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, self.current_param_name, param_config, + self.current_function_name or "") + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + if self.param_count == 0: + json_fragment = (f'"{self.current_param_name}": ' + f'{serialized_value}') + else: + json_fragment = (f', "{self.current_param_name}": ' + f'{serialized_value}') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value - Not used in the current implementation + # since we process complete parameters above if self.in_param: if self.parameter_end_token in delta_text: # End of parameter @@ -608,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] - # Calculate incremental JSON + # Store complete value full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") - full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] - + self.accumulated_params[ + self.current_param_name] = full_value + + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert the parameter value to the appropriate type + converted_value = self._convert_param_value( + full_value, self.current_param_name or "", + param_config, self.current_function_name or "") + + # Serialize the converted value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + # Since we've been streaming the quoted version, + # we need to close it properly + # This is complex - for now just complete the value self.in_param = False self.current_param_value = "" + # Just close the current parameter string return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, function=DeltaFunctionCall( - arguments=delta_escaped + '"'), + arguments='"'), # Close the string quote ) ]) else: @@ -638,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = json.dumps( + self.current_param_value, ensure_ascii=False + )[1:-1] if self.current_param_value else "" self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] + full_escaped = json.dumps(self.current_param_value, + ensure_ascii=False)[1:-1] delta_escaped = full_escaped[len(prev_escaped):] if delta_escaped: @@ -661,4 +704,4 @@ class Qwen3CoderToolParser(ToolParser): ) ]) - return None + return None \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..95458f07ff2a2b99a1e22bbf404b550f1e2d6ccf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -0,0 +1,679 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from qwen3coder xml parser, All rights reserved. +# ruff: noqa: E501 + +import ast +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("seed_oss") +class SeedOssToolParser(ToolParser): + TOOL_CALL_START = "<seed:tool_call>" + TOOL_CALL_END = "</seed:tool_call>" + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # --- streaming state --- + self._reset_streaming_state() + self.prev_tool_call_arr: list[dict] = [] + + self.tool_call_start_token: str = self.TOOL_CALL_START + self.tool_call_end_token: str = self.TOOL_CALL_END + # Sentinel tokens for streaming mode + self.tool_call_prefix: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_prefix: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" + self.think_start_token: str = "<seed:think>" + self.think_end_token: str = "</seed:think>" + self.is_tool_call_started: bool = False + self.is_thinking_end: bool = False + self.failed_count: int = 0 + self._reset_streaming_state() + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Seed_Oss XML parser: tokenizer did not include " + "<seed:tool_call> or its closing tag.") + + tool_start_re = re.escape(self.tool_call_start_token) + tool_end_re = re.escape(self.tool_call_end_token) + + self.tool_call_complete_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + self.tool_call_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", + re.DOTALL) + + self.tool_call_function_regex = re.compile( + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + self.tool_call_parameter_regex = re.compile( + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + + logger.info("vLLM Seed-Oss XML tool parser loaded (%s).", + self.__class__.__name__) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = -1 + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if (config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in " + "the tool parameters for tool '%s', " + "directly returning the string value.", param_name, + func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + param_value = int(param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith( + "float"): + try: + float_param_value = float(param_value) + param_value = float_param_value if float_param_value - int( + float_param_value) != 0 else int( + float_param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` of `false`) in tool '%s', degenerating to false.", + param_value, param_name, func_name) + return param_value == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + param_value = json.loads(param_value) + return param_value + except (ValueError, TypeError, json.JSONDecodeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a valid JSON " + "object in tool '%s', will try other methods to parse it.", + param_value, param_name, func_name) + try: + param_value = ast.literal_eval(param_value) + except (ValueError, SyntaxError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be converted via " + "Python `ast.literal_eval()` in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # Check if both think start and end tokens are present + if (self.think_start_token in model_output + and self.think_end_token in model_output): + # Find the position of think end token + think_end_index = model_output.find(self.think_end_token) + len( + self.think_end_token) + # Extract content after think end token + result_content = model_output[think_end_index:] + thinking_content = model_output[:think_end_index] + else: + thinking_content = "" + result_content = model_output + + try: + function_calls = self._get_function_calls(result_content) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + tool_call_start_index = result_content.find( + self.tool_call_start_token) + tool_call_start_index = ( + tool_call_start_index if tool_call_start_index >= 0 else + result_content.find(self.tool_call_prefix)) + content = thinking_content + result_content[:tool_call_start_index] + + return ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless + # it's an EOS token after tool calls + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) + if open_calls == 0: + # Return empty delta message to allow finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if this is the first call (reset state if needed) + if not previous_text: + self._reset_streaming_state() + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + + # Check if there are more tool calls + if self.current_tool_index >= current_text.count( + self.tool_call_start_token): + # No more tool calls + self.is_tool_call_started = False + # Continue processing next tool + return None + + # Check if end thinking + if (not self.is_thinking_end + and (self.think_end_token_id in delta_token_ids + or self.think_end_token in delta_text)): + self.is_thinking_end = True + + # If thinking hasn't ended yet, don't process any tool calls + if not self.is_thinking_end: + return DeltaMessage(content=delta_text) + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if (self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[:delta_text.index( + self.tool_call_start_token)] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if (current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == ""): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + # Count tool calls we've seen vs processed + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # We're in a tool call, find the current tool call portion + # Need to find the correct tool call based on current_tool_index + # Only process tool calls after think_end_token + think_end_index = current_text.find(self.think_end_token) + len( + self.think_end_token + ) if self.think_end_token in current_text else 0 + tool_starts: list[int] = [] + idx = think_end_index + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + # No more tool calls to process yet + return None + + tool_start_idx = tool_starts[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + tool_end_idx = current_text.find(self.tool_call_end_token, + tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[tool_start_idx:tool_end_idx + + len(self.tool_call_end_token)] + + # Looking for function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_id = self._generate_tool_call_id( + ) # type: ignore + self.header_sent = True + self.in_function = True + + # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call + # This ensures finish_reason="tool_calls" even if parsing isn't complete + already_added = any( + tool.get("name") == self.current_function_name + for tool in self.prev_tool_call_arr) + if not already_added: + self.prev_tool_call_arr.append({ + "name": self.current_function_name, + "arguments": + "{}", # Placeholder, will be updated later + }) + + # Send header with function info + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments=""), + type="function", + ) + ]) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if (not self.json_started + and self.parameter_prefix not in delta_text): + self.json_started = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ]) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.function_end_token in tool_text: + # Close JSON + self.json_closed = True + + # Extract the complete tool call to update prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_content_end = tool_text.find(self.function_end_token, + func_start) + if func_content_end != -1: + func_content = tool_text[func_start:func_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_xml_function_call( + func_content, request.tools if request else None) + if parsed_tool: + # Update existing entry in prev_tool_call_arr with complete arguments + for i, tool in enumerate(self.prev_tool_call_arr): + if tool.get( + "name") == parsed_tool.function.name: + self.prev_tool_call_arr[i]["arguments"] = ( + parsed_tool.function.arguments) + break + except Exception: + logger.warning( + "Failed to parse tool arguments during streaming.", + exc_info=True) + + result = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ]) + + # Reset state for next tool + self.in_function = False + self.json_closed = True + + return result + + # Look for parameters + # Count how many complete parameters we have processed + complete_params = tool_text.count(self.parameter_end_token) + + # Check if we should start a new parameter + if not self.in_param and self.param_count < complete_params: + # Find the unprocessed parameter + # Count parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find( + self.parameter_end_token) + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Build complete JSON fragment for this parameter + if self.param_count == 0: + json_fragment = ( + '"' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + else: + json_fragment = ( + ', "' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value + if self.in_param: + if self.parameter_end_token in delta_text: + # End of parameter + end_idx = delta_text.find(self.parameter_end_token) + value_chunk = delta_text[:end_idx] + + # Skip past > if at start + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + # Calculate incremental JSON + full_value = self.current_param_value + value_chunk + prev_escaped = (json.dumps(self.current_param_value)[1:-1] + if self.current_param_value else "") + full_escaped = json.dumps(full_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + self.in_param = False + self.current_param_value = "" + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"'), + ) + ]) + else: + # Continue accumulating value + value_chunk = delta_text + + # Handle first chunk after param name + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + if value_chunk: + # Stream the escaped delta + prev_escaped = (json.dumps( + self.current_param_value)[1:-1] + if self.current_param_value else "") + self.current_param_value += value_chunk + full_escaped = json.dumps( + self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + if delta_escaped: + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped), + ) + ]) + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 321718b1c950b3c09c213d9fe5d2d5ccf5166de6..484e904cd8c3687f85e4f698662bbdc3c82a43b8 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser): """ Extract tool calls for streaming mode. """ - # Simplify detection: if it begins with "[" treat it as a function call - is_function_call = (current_text.strip().startswith("[")) - - # If not a function call, return normal content - if not is_function_call: + # First, check for a definitive start of a tool call block. + # This prevents premature parsing of incomplete output. + stripped_text = current_text.strip() + preprocessed_content, preprocessed_tool_calls = ( + self.preprocess_model_output(current_text)) + + # For JSON code blocks, we need to detect them earlier, even if incomplete + has_potential_json_block = ("```json" in current_text + or "```\n[" in current_text + or "[TOOL_CALLS]" in current_text + or "<tool_call>" in current_text) + + is_tool_call_block = ( + stripped_text.startswith("[") + or stripped_text.startswith("<tool_call>") + or stripped_text.startswith("[TOOL_CALLS]") or + # Check if we have thinking tags with JSON-like content following + ("</think>[" in current_text) or + # Check if the text contains a JSON array after preprocessing + preprocessed_tool_calls is not None or + # For JSON code blocks, detect early if we see enough structure + (has_potential_json_block and '"name"' in current_text + and '"arguments"' in current_text)) + + if not is_tool_call_block: return DeltaMessage(content=delta_text) try: @@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser): # Try parsing as JSON to check for complete tool calls try: - parsed_tools = json.loads(current_text) + # Use preprocessed tool calls if available + tool_calls_text = (preprocessed_tool_calls if + preprocessed_tool_calls else current_text) + parsed_tools = json.loads(tool_calls_text) if isinstance(parsed_tools, list): # Update our tool array for next time self.prev_tool_call_arr = parsed_tools @@ -226,7 +249,7 @@ class xLAMToolParser(ToolParser): function_name = name_match.group(1) # The test expects us to send just the name first - tool_id = random_tool_call_id() + tool_id = make_tool_call_id() delta = DeltaMessage(tool_calls=[ DeltaToolCall( index=0, @@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser): return delta # Use regex to identify tool calls in the output + # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks + search_text = (preprocessed_tool_calls + if preprocessed_tool_calls else current_text) + + # For JSON code blocks that aren't complete yet, try to extract the JSON content + if not preprocessed_tool_calls and has_potential_json_block: + # Try to extract the JSON array from within the code block + json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)", + current_text) + if json_match: + potential_json = json_match.group(1).strip() + # Use this as search text even if it's incomplete + if potential_json.startswith("[") and ( + '"name"' in potential_json + and '"arguments"' in potential_json): + search_text = potential_json + + # Try to find complete tool names first name_pattern = r'"name"\s*:\s*"([^"]+)"' - name_matches = list(re.finditer(name_pattern, current_text)) + name_matches = list(re.finditer(name_pattern, search_text)) tool_count = len(name_matches) - # If no tools found yet, return + # If no complete tool names found, check for partial tool names if tool_count == 0: - return None + # Check if we're in the middle of parsing a tool name + partial_name_pattern = r'"name"\s*:\s*"([^"]*)' + partial_matches = list( + re.finditer(partial_name_pattern, search_text)) + if partial_matches: + # We have a partial tool name - not ready to emit yet + return None + else: + # No tools found at all + return None # Ensure our state arrays are large enough while len(self.streaming_state["sent_tools"]) < tool_count: @@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser): # First, check for the empty arguments case: "arguments": {} empty_args_pattern = ( r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') - empty_args_match = re.search(empty_args_pattern, current_text) + empty_args_match = re.search(empty_args_pattern, search_text) # Check if this tool has empty arguments if empty_args_match and empty_args_match.start() > 0: @@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser): # Extract arguments for current tool using regex for non-empty arguments args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})' - args_matches = list(re.finditer(args_pattern, current_text)) + args_matches = list(re.finditer(args_pattern, search_text)) if current_idx < len(args_matches): args_text = args_matches[current_idx].group(1) @@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser): # Handle transition between tools is_last_tool = current_idx == tool_count - 1 - # Find where the arguments for our current tool end - if not is_last_tool: - # If we have more tools after this one, try to find the complete argument block - next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) - if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1 - ) # +1 to include the '}' - args_text = (current_text[args_matches[current_idx] - .start():args_end_pos]. - split('"arguments":')[1].strip()) + # For multiple tools, extract only the arguments for the current tool + if tool_count > 1: + # Parse the entire JSON structure to properly extract arguments for each tool + try: + parsed_tools = json.loads(search_text) + if isinstance( + parsed_tools, + list) and current_idx < len(parsed_tools): + current_tool = parsed_tools[current_idx] + if isinstance(current_tool.get("arguments"), + dict): + args_text = json.dumps( + current_tool["arguments"]) + else: + args_text = str( + current_tool.get("arguments", "{}")) + except (json.JSONDecodeError, KeyError, IndexError): + # Fallback to regex-based extraction + pass # If arguments haven't been sent yet sent_args = self.streaming_state["sent_tools"][ @@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser): index=current_idx, function=DeltaFunctionCall( arguments="{").model_dump( - exclude_none=True), # type: ignore + exclude_none=True), # type: ignore ) ]) return delta diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index d8905fc141245a860de0f79562733e8a3cfeb2a2..d2d7dba3ae46fa619fe1055ef48441f012b83e6c 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -313,12 +313,14 @@ def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]): # Handle EngineArgs instance elif isinstance(args, EngineArgs): - default_args = EngineArgs() # Create default instance + default_args = EngineArgs(model=args.model) # Create default instance for field in dataclasses.fields(args): current_val = getattr(args, field.name) default_val = getattr(default_args, field.name) if current_val != default_val: non_default_args[field.name] = current_val + if default_args.model != EngineArgs.model: + non_default_args["model"] = default_args.model else: raise TypeError("Unsupported argument type. " \ "Must be argparse.Namespace or EngineArgs instance.") diff --git a/vllm/envs.py b/vllm/envs.py index 8b394385482ad74c5a397b8720bc15542f249228..41611a3378e72bb60812c3b52e96c947e6795858 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +import json import os import sys import tempfile @@ -42,7 +43,6 @@ if TYPE_CHECKING: VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -99,6 +99,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -131,7 +132,9 @@ if TYPE_CHECKING: VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" @@ -159,9 +162,12 @@ if TYPE_CHECKING: VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None + VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False # add envs VLLM_OPTEST_URLS_PORT: Optional[int] = None @@ -188,6 +194,7 @@ if TYPE_CHECKING: VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_APEX_RN: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False + VLLM_USE_LIGHT_OP: bool = False def get_default_cache_root(): @@ -491,11 +498,6 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - # If set, vllm will force flashinfer to use tensor cores; - # otherwise will use heuristic based on model architecture. - "VLLM_FLASHINFER_FORCE_TENSOR_CORES": - lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), - # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), @@ -693,11 +695,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), - # Enables torch profiler if set. Path to the directory where torch profiler - # traces are saved. Note that it must be an absolute path. + # Enables torch profiler if set. + # Both AsyncLLM's CPU traces as well as workers' + # traces (CPU & GPU) will be saved under this directory. + # Note that it must be an absolute path. "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os - .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + .path.abspath(os.path.expanduser(os.getenv( + "VLLM_TORCH_PROFILER_DIR", ".")))), # Enable torch profiler to record shapes if set # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will @@ -797,6 +802,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")), + # Whether to use aiter triton fp8 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP8BMM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in + ("true", "1")), + # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in @@ -979,9 +990,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - # E8M0 is faster on B200 but may reduce accuracy. "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + # TODO(wentao): unify the two E8M0 flags after verifying the correctness. + # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -990,6 +1004,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), + # Whether to use fused grouped_topk used for MoE expert selection. + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": + lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), + # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), @@ -1068,6 +1086,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + # Specifies the thresholds of the communicated tensor sizes under which + # vllm should use flashinfer fused allreduce. The variable should be a + # JSON with the following format: + # { <world size>: <max size in mb> } + # Unspecified world sizes will fallback to + # { 2: 64, 4: 1, <everything else>: 0.5 } + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": + lambda: json.loads(os.getenv( + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")), + # MoE routing strategy selector. # See `RoutingSimulator.get_available_strategies()` # for available # strategies. @@ -1134,6 +1162,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set, it means we pre-downloaded cubin files and flashinfer will + # read the cubin files directly. + "VLLM_HAS_FLASHINFER_CUBIN": + lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # Otherwise, uses the first available of: flashinfer cutlass GEMM, # vllm cutlass GEMM, marlin GEMM. @@ -1146,6 +1179,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), + # Disable padding to CUDA graph capture batch sizes. + # TODO(wentao): https://github.com/vllm-project/vllm/issues/23378 + # After the issue is fixed, we can remove this flag. + "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": + lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))), + # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), @@ -1179,6 +1218,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + # Whether to use pytorch symmetric memory for allreduce + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), + # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), @@ -1294,6 +1337,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_GLOBAL_CACHE13": lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in ("true", "1")), + + # vLLM will use lightop for moe_fused_gate and moe_align_block_size + "VLLM_USE_LIGHT_OP": + lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in + ("true", "1")), } # --8<-- [end:env-vars-definition] @@ -1355,10 +1403,12 @@ def compute_hash() -> str: "VLLM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ATTENTION_BACKEND", "VLLM_USE_FLASHINFER_SAMPLER", - "VLLM_FLASHINFER_FORCE_TENSOR_CORES", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "VLLM_USE_TRTLLM_FP4_GEMM", + "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", @@ -1372,6 +1422,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_RMSNORM", "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", + "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_SKINNY_GEMM", "VLLM_ROCM_FP8_PADDING", "VLLM_ROCM_MOE_PADDING", diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py index 52ac4725dd72669ae2010ad1ef9cda46ea6ecbcc..6205275c18e72258af3431300ddf2809ae2e5618 100644 --- a/vllm/executor/mp_distributed_executor.py +++ b/vllm/executor/mp_distributed_executor.py @@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase): result_handler.start() self.worker_monitor.start() - # Set up signal handlers to shutdown the executor cleanly + # Set up signal handlers to shut down the executor cleanly # sometimes gc does not work well self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 852c8f5cffa0c5408f1a5a5f7f984b28be47e368..4ce6d8dfad2cc4d2277202c9f23e48d69a5c9513 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -4,11 +4,12 @@ from array import array from typing import Any, Type +from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE def encode_hook(obj: Any) -> Any: - """Custom msgspec enc hook that supports array types. + """Custom msgspec enc hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any: f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " f"Given array has a type code of {obj.typecode}.") return obj.tobytes() + if isinstance(obj, MultiModalKwargs): + return dict(obj) def decode_hook(type: Type, obj: Any) -> Any: - """Custom msgspec dec hook that supports array types. + """Custom msgspec dec hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any: deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) deserialized.frombytes(obj) return deserialized + if type is MultiModalKwargs: + return MultiModalKwargs(obj) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7abaffa54c0894be6fe17bc3abe41fa0933c942c..4b2a15afb67a7d4530381cb313e7cc1817601c54 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -10,6 +10,7 @@ import msgspec import vllm.platforms from vllm.config import ParallelConfig +from vllm.distributed import get_pp_group from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform @@ -136,6 +137,11 @@ try: scheduler_output, intermediate_tensors) if isinstance(output, IntermediateTensors): output = scheduler_output, output + elif not get_pp_group().is_last_rank: + # Case where there are no scheduled requests + # but may still be finished requests. + assert not output or not output.req_ids + output = scheduler_output, None return output def override_env_vars(self, vars: Dict[str, str]): diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index aef7841e71b714174a18f31fed12c837a61d9aa6..e9db2a0dc13a8d21298fa1819ad7a2ab30afbfc4 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, +from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, @@ -18,6 +18,7 @@ target model. """ __all__ = [ + "DataPrompt", "TextPrompt", "TokensPrompt", "PromptType", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d143999e88cad4ef1cd22a6492f2ab769b83fcee..377ef8ab114eed371c477550a4fc9fb7d2732f78 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,8 @@ import torch from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: - from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs + from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs, + MultiModalUUIDDict) class TextPrompt(TypedDict): @@ -30,6 +31,15 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching, and MUST be unique per + multimodal item. + """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -59,6 +69,14 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching. + """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict): """ +class DataPrompt(TypedDict): + """Represents generic inputs handled by IO processor plugins.""" + + data: Any + """The input data""" + + data_format: str + """The input data format""" + + SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: @@ -174,9 +202,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -190,7 +215,6 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: list[int], - token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, cache_salt: Optional[str] = None, ) -> TokenInputs: @@ -200,8 +224,6 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt - if token_type_ids is not None: - inputs["token_type_ids"] = token_type_ids if cache_salt is not None: inputs["cache_salt"] = cache_salt diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 04fa709a0f15eb4635a66cf72202a5d8e3b5220a..0f126ee251ae055a9b694e62e80565efd40f779c 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,8 +11,9 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs) + MultiModalInputs, MultiModalUUIDDict) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -32,12 +33,14 @@ class InputPreprocessor: model_config: ModelConfig, tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: @@ -257,7 +260,9 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -265,17 +270,22 @@ class InputPreprocessor: """ tokenizer = self._get_mm_tokenizer(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + return mm_processor.apply( + prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) async def _process_multimodal_async( self, @@ -284,7 +294,9 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: """ Async version of @@ -292,16 +304,22 @@ class InputPreprocessor: """ tokenizer = await self._get_mm_tokenizer_async(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + return mm_processor.apply( + prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) def _process_embeds( self, @@ -333,15 +351,33 @@ class InputPreprocessor: ) -> EmbedsInputs: return self._process_embeds(parsed_content) + def _truncate_inputs( + self, + inputs: list[int], + tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]: + + if not tokenization_kwargs or "truncation" not in \ + tokenization_kwargs or self.tokenizer is None: + return inputs + + max_length = tokenization_kwargs["max_length"] + + if self.tokenizer.truncation_side == "left": + return inputs[-max_length:] + else: + return inputs[:max_length] + def _process_tokens( self, parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") + prompt_token_ids = self._truncate_inputs( + parsed_content["prompt_token_ids"], tokenization_kwargs) inputs: Union[TokenInputs, MultiModalInputs] if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -351,13 +387,10 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + inputs = token_inputs(prompt_token_ids=prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -369,10 +402,12 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") + prompt_token_ids = self._truncate_inputs( + parsed_content["prompt_token_ids"], tokenization_kwargs) inputs: Union[TokenInputs, MultiModalInputs] if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -382,13 +417,10 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -400,7 +432,9 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -412,7 +446,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: prompt_token_ids = self._tokenize_prompt( @@ -435,7 +469,9 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -447,7 +483,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: prompt_token_ids = await self._tokenize_prompt_async( @@ -470,7 +506,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. @@ -479,7 +517,6 @@ class InputPreprocessor: * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts - * return_mm_hashes: whether to return multimodal hashes Returns: @@ -493,21 +530,21 @@ class InputPreprocessor: return self._process_tokens( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) assert_never(parsed) @@ -517,7 +554,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> SingletonInputs: """ Async version of @@ -531,21 +570,21 @@ class InputPreprocessor: return await self._process_tokens_async( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) assert_never(parsed) @@ -655,6 +694,9 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -696,6 +738,7 @@ class InputPreprocessor: encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None @@ -711,6 +754,7 @@ class InputPreprocessor: inputs = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -726,6 +770,9 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> EncoderDecoderInputs: """ Async version of @@ -738,6 +785,7 @@ class InputPreprocessor: encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if (decoder_input := prompt["decoder_prompt"]) is None: @@ -747,6 +795,7 @@ class InputPreprocessor: decoder_task = self._prompt_to_llm_inputs_async( decoder_input, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) encoder_inputs, decoder_inputs = await asyncio.gather( @@ -762,6 +811,7 @@ class InputPreprocessor: inputs = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -788,7 +838,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -799,7 +851,6 @@ class InputPreprocessor: * prompt: input prompt * lora_request - * return_mm_hashes Returns: @@ -810,7 +861,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -820,7 +871,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> DecoderOnlyInputs: """ Async version of @@ -830,7 +883,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -840,17 +893,19 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder + # input prompts to encoder & decoder. return self._process_encoder_decoder_prompt( - prompt, tokenization_kwargs) + prompt, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -861,7 +916,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) async def preprocess_async( @@ -869,19 +924,22 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> ProcessorInputs: """ Async version of [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. """ if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - return await self._process_encoder_decoder_prompt_async(prompt) + # input prompts to encoder & decoder. + return await self._process_encoder_decoder_prompt_async( + prompt, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -892,5 +950,9 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) + + def clear_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ef146fdfbf97cb513722c4f5b931e5644e7c8527..f0b392e9767aec975d36f27112c5f202c7592250 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -223,20 +223,26 @@ class InputRegistry: The model is identified by ``model_config``. """ # Avoid circular import + from vllm.multimodal.cache import processor_only_cache_from_config from vllm.sequence import SequenceData if not model_config.is_multimodal_model: seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(seq_data=seq_data) + cache = processor_only_cache_from_config(model_config, mm_registry) + # Encoder dummy data does not contain multi-modal data if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) + enc_data = mm_registry.get_encoder_dummy_data(model_config, + seq_len, + cache=cache) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) return DummyData(seq_data=seq_data) - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + dec_data = mm_registry.get_decoder_dummy_data(model_config, + seq_len, + cache=cache) return DummyData( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index de5933d6d41e5732e94af31fc1dc463a25ff3585..d8503b20459f6d18e8e6dc9707a3253595760f99 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device - # marlin - elif hasattr(base_layer, "B"): - return base_layer.B.device # HQQ marlin elif hasattr(base_layer, "W_q"): return base_layer.W_q.device @@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (eg. gate_proj + up_proj -> gate_up_proj). + packed together (e.g. gate_proj + up_proj -> gate_up_proj). This means we have 2 LoRAs, each applied to one half of the layer. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 4949b1c7b569d634b8e16d5e3ef311eb8f6dc021..d16ddbc0be33505aabd58a267a470d99f468cd25 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -207,6 +207,7 @@ class LoRAModel(AdapterModel): """ lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") new_embeddings_tensor_path = os.path.join( lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, @@ -255,9 +256,10 @@ class LoRAModel(AdapterModel): check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path): - # When a bin file is provided, we rely on config to find unexpected - # modules. + elif os.path.isfile(lora_bin_file_path) or os.path.isfile( + lora_pt_file_path): + # When a bin/pt file is provided, we rely on config to find + # unexpected modules. unexpected_modules = [] target_modules = peft_helper.target_modules if not isinstance(target_modules, list): @@ -279,7 +281,10 @@ class LoRAModel(AdapterModel): f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." f" Please verify that the loaded LoRA module is correct") - tensors = torch.load(lora_bin_file_path, + lora_file_path = (lora_bin_file_path + if os.path.isfile(lora_bin_file_path) else + lora_pt_file_path) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 1f1f77e9497455d848a135730fe2d6b1cb18eece..7984fea6dac8ba81fb500572164a151f531ec50d 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -10,12 +10,15 @@ import torch.nn.functional as F from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import LazyDict import vllm.envs as envs +logger = init_logger(__name__) + @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): @@ -373,6 +376,112 @@ class ReLUSquaredActivation(CustomOp): return self.forward_native(x) +@CustomOp.register("xielu") +class XIELU(CustomOp): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta: float = 0.5, + eps: float = -1e-6, + dtype: torch.dtype = torch.bfloat16, + with_vector_loads: bool = False, + ): + super().__init__() + self.alpha_p = nn.Parameter( + torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - + 1).unsqueeze(0)) + self.alpha_n = nn.Parameter( + torch.log( + torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - + 1).unsqueeze(0)) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) + + self._xielu_cuda_obj = None + try: + import xielu.ops # noqa: F401 + + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = "Using experimental xIELU CUDA." + try: + from torch._dynamo import allow_in_graph + + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += " Enabled torch._dynamo for xIELU CUDA." + except Exception as err: + msg += (f" Could not enable torch._dynamo for xIELU ({err}) - " + "this may result in slower performance.") + self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) –" + " falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err), + ) + + def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + + self.beta * x, + ) + + def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: + """Firewall function to prevent torch.compile from seeing .item()""" + assert self._xielu_cuda_obj is not None, ( + "XIELU CUDA object must not be None") + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions" + " but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p, + self.alpha_n, + # Temporary until xIELU CUDA fully implemented -> + # self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + if not torch._dynamo.is_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning_once( + "torch._dynamo is compiling, using Python version of xIELU." + ) + return self._xielu_python(input) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. @@ -432,12 +541,25 @@ _ACTIVATION_REGISTRY = LazyDict({ lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), + "tanh": + lambda: nn.Tanh(), + "sigmoid": + lambda: nn.Sigmoid(), + "xielu": + lambda: XIELU(), }) def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("torch.nn.modules."): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + if act_fn_name not in _ACTIVATION_REGISTRY: raise ValueError( f"Activation function {act_fn_name!r} is not supported.") diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..782818f55fbc2d783f1bfea36c2df3212c6aa38c --- /dev/null +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for attention-like layers.""" +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class AttentionLayerBase(ABC): + """ + Base class for attention-like layers (Attention, Mamba, etc.) + that support the v1 engine. + + This provides a common interface for getting attention backends + from different layer types. + """ + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this layer.""" + pass diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index d9cfe96f7a033c9a22497f045b2cc9fbb3e89488..a5326dfe84f6d5de98bfe1611a74f7a60b974b0a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_e8m0_used) + is_deep_gemm_e8m0_used) logger = init_logger(__name__) @@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm( # number of valid tokens for this expert n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) - cols = tl.arange(0, BLOCK) - cols = cols.to(tl.int64) - mask_h = cols < BLOCK + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + + cols * stride_yq_h) + base_ys_offset = e * stride_ys_e + g * stride_ys_g for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): - base_i_offset = (e * stride_i_e + t * stride_i_t + - g * GROUP_SIZE * stride_i_h) - base_yq_offset = (e * stride_yq_e + t * stride_yq_t + - g * GROUP_SIZE * stride_yq_h) - base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g - - mask = mask_h - x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, - mask=mask, - other=0.0).to(tl.float32) - y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + - cols * stride_i_h, + gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, + mask=mask, + other=0.0).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, - other=0.0).to(tl.float32) + other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up - x = x * (1.0 / (1.0 + tl.exp(-x))) - y = x * y2 + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) - _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - scale_raw = _absmax / fp8_max - y_s = tl.math.exp2(tl.ceil( - tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) - tl.store(y_s_ptr + base_ys_offset, y_s) + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) def silu_mul_fp8_quant_deep_gemm( - y: torch.Tensor, # (E, T, 2*H) float32 + y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert group_size: int = 128, eps: float = 1e-10, -): +) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. Returns `(y_q, y_s)` where - * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. - * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) """ assert y.ndim == 3, "y must be (E, T, 2*H)" E, T, H2 = y.shape @@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm( stride_cnt_e = tokens_per_expert.stride()[0] - # static grid over experts and H-groups. + # Static grid over experts and H-groups. # A loop inside the kernel handles the token dim grid = (E * G, ) @@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, - is_blackwell_deep_gemm_e8m0_used(), + is_deep_gemm_e8m0_used(), BLOCK=group_size, - NUM_STAGES=8, + NUM_STAGES=4, num_warps=1, ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6b06275e56a69df71b79e7c5680a8192aa581dc5..84924c2e78e27dc6b838437dfa062ed36c90a9e8 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -194,12 +194,6 @@ class FusedMoEParallelConfig: return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - @property - def use_flashinfer_cutlass_kernels(self): - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") - @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -408,7 +402,14 @@ class FusedMoEConfig: @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + """ + Whether to use FlashInfer cutlass kernels for NVFP4 MoE. + """ + return (self.quant_config is not None + and self.quant_config.quant_dtype == "nvfp4" + and envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") @staticmethod def make( @@ -454,6 +455,12 @@ class FusedMoEConfig: if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Config) + if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + quant_dtype = "mxfp8" + from vllm.model_executor.layers.quantization.modelopt import ( ModelOptNvFp4Config) if quant_dtype is None and isinstance(quant_config, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000000000000000000000000000000..63de4bfa4cb526619dc99955f211fc70770a29fb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,122 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000000000000000000000000000000..b962d19506ce5896ef71d0d260c57575c796f1e3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000000000000000000000000000000..6efcc02b4d9a2f6b6f3929343bc51fc4a1e680e1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,114 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..d677d69c57a25d0ec31d6fa791d9f06fdb9effd8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index e67ff66882102c2792291b317d6d2627453034b2..0eec93601b3f2fc30ecc17d49934c3fa1e4ba868 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -3,10 +3,115 @@ from typing import Callable, Optional import torch +from torch.nn import functional as F from vllm import envs +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights, topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + return grouped_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids.to(torch.int32) + else: + return custom_routing_function(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + class IPEXFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: @@ -31,12 +136,15 @@ class IPEXFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input + assert routed_scaling_factor == 1.0, \ + f"routed_scaling_factor {routed_scaling_factor} is not supported." return layer.ipex_fusion( x, use_grouped_topk, @@ -56,113 +164,6 @@ class SGLFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: pass - @staticmethod - def _grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - - gating_output = gating_output.float() - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.shape[0] - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use - # biased scores for expert selection but original scores for - # routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) - else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, - k=topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, - -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - - return topk_weights, topk_ids.to(torch.int32) - - @staticmethod - def _select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = SGLFusedMOE._grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - elif custom_routing_function is None: - assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) - if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - topk_ids = topk_ids.to(torch.int32) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - def __call__( self, layer: torch.nn.Module, @@ -177,13 +178,14 @@ class SGLFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input - topk_weights, topk_ids = SGLFusedMOE._select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -193,6 +195,7 @@ class SGLFusedMOE: num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, ) @@ -213,3 +216,82 @@ class SGLFusedMOE: True, ) return x + + +class CPUFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 + len_experts = global_num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w2_weight = layer.w2_weight[i] + + gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) + gate_up = silu_and_mul(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, + dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = (new_x.view( + *topk_ids.shape, -1).type(topk_weights.dtype).mul_( + topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)) + return final_out diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0a02b558d09e58d88b2839243aadc429df7bf385..95d23ec0346c1c38413ebccf371bdeb2c5cfa91e 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.scalar_type import scalar_types @@ -34,6 +35,10 @@ def run_cutlass_moe_fp8( w2_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], @@ -41,6 +46,7 @@ def run_cutlass_moe_fp8( per_act_token: bool, per_out_ch: bool, use_batched_format: bool, + topk_weights: Optional[torch.Tensor], ): a1q = hidden_states @@ -99,6 +105,22 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.size(1) local_E = w1.size(0) + if use_batched_format: + mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + act_out = _resize_cache(workspace2, (local_E * padded_M, N)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (local_E * padded_M, N)) + mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) + else: + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), + (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M * topk, N)) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + if use_batched_format: assert expert_num_tokens is not None @@ -120,11 +142,10 @@ def run_cutlass_moe_fp8( w2_scale = w2_scale.reshape(w2_scale.size(0), -1) a1q = a1q.reshape(-1, a1q.size(2)) a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() - + # c3x get_group_gemm_starts expects int64 to avoid overflow + # during offset calculations + expert_offsets = expert_offsets.to(torch.int64) else: - expert_offsets = torch.empty((global_num_experts + 1), - dtype=torch.int32, - device=device) problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) @@ -132,84 +153,57 @@ def run_cutlass_moe_fp8( dtype=torch.int32, device=device) - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - if expert_map is not None: - a_map = torch.zeros((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - else: - a_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, a_map, - c_map, global_num_experts, N, K) - - a1q = _fp8_perm(a1q, a_map) - a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + num_expert = global_num_experts if expert_map is None \ + else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm) expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - c_strides1 = torch.full((w1.size(0), ), - 2 * N, - device=device, - dtype=torch.int64) - ab_strides2 = torch.full((w1.size(0), ), - N, - device=device, - dtype=torch.int64) - c_strides2 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - - if use_batched_format: - c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) - c2 = _resize_cache(workspace2, (local_E * padded_M, N)) - c3 = _resize_cache(workspace13, (local_E * padded_M, K)) - else: - c1 = _resize_cache(workspace13, (M * topk, N * 2)) - c2 = _resize_cache(workspace2, (M * topk, N)) - c3 = _resize_cache(workspace13, (M * topk, K)) + ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, + problem_sizes2, + global_num_experts, N, K) if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by # random data in the unused workspace. The workspace is unused when # this rank handles only partial tokens, or when it is batched . - c1.fill_(0) + mm1_out.fill_(0) - ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, per_act_token, per_out_ch) - activation_callable(c2, c1) + activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - c2, a2_scale, use_per_token_if_dynamic=per_act_token) + act_out, + a2_scale, + use_per_token_if_dynamic=per_act_token, + output=quant_out) if expert_map is not None: - c3.fill_(0) + mm2_out.fill_(0) - ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) if use_batched_format: - output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) + output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: - # We can't do this inplace because output may point to the same tensor - # as c3. - output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. + moe_unpermute(out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm) class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): @@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( @@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): block_shape=block_shape, )) self.out_dtype = out_dtype + self.ab_strides1 = ab_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. @@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, + a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, + self.c_strides2, workspace13, workspace2, expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, self.per_act_token_quant, self.per_out_ch_quant, - use_batched_format) + use_batched_format, topk_weights) class CutlassExpertsFp8(CutlassExpertsFp8Base): @@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( out_dtype, per_act_token_quant, per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, block_shape, ) @@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() + def workspace_shapes( self, a: torch.Tensor, @@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, N // 2) - output = (M * topk, K) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) @@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, ): super().__init__( out_dtype, per_act_token_quant, per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, block_shape, ) assert max_experts_per_worker > 0 @@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): assert num_dp is not None workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + max(N // 2, K)) output = (self.max_experts_per_worker, padded_M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) @@ -392,6 +416,10 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, @@ -419,6 +447,17 @@ def cutlass_moe_fp8( Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] + - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. + Shape: [num_experts] + - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. + Shape: [num_experts] + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to @@ -450,6 +489,10 @@ def cutlass_moe_fp8( out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, ), ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3fbe2a0bc69bba5473d4693c366e89b7dc0b55a3..feab3f74cac53317346507b8272b4fbf984e7609 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -7,6 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, @@ -59,8 +61,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): per_act_token_quant=False, block_shape=None, )) - assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is " - "currently supported.") + assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank @@ -120,7 +122,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): """ aq_m, aq_n = aq.shape workspace2 = () - output_shape = (aq_m, aq_n * 2) + output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ + torch.float8_e4m3fn else (aq_m, aq_n) workspace_dtype = a.dtype workspace1 = output_shape # The workspace is determined by `aq`, since it comes after any @@ -149,29 +152,39 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], ): - # Flashinfer CUTLASS kernel takes scalar global scales, - # min because inv_scale. - - # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") - - quant_scales = [ - self.a1_gscale, - w1_scale.view(torch.int32), - self.g1_alphas, - self.a2_gscale, - w2_scale.view(torch.int32), - self.g2_alphas, - ] + if self.quant_dtype == torch.float8_e4m3fn: + quant_scales = [ + self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + ] + + a1q_scale = None # not passing input_sf in fp8 + fc1_expert_weights = w1 + fc2_expert_weights = w2 + else: + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. + quant_scales = [ + self.a1_gscale, + w1_scale.view(torch.int32), + self.g1_alphas, + self.a2_gscale, + w2_scale.view(torch.int32), + self.g2_alphas, + ] + # FlashInfer API requires weight to be long for nvfp4 + fc1_expert_weights = w1.view(torch.long) + fc2_expert_weights = w2.view(torch.long) + _ = flashinfer_cutlass_fused_moe( input=hidden_states, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, - # FlashInfer API requires weight to be long for nvfp4 - fc1_expert_weights=w1.view(torch.long), - fc2_expert_weights=w2.view(torch.long), + fc1_expert_weights=fc1_expert_weights, + fc2_expert_weights=fc2_expert_weights, output_dtype=self.out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, @@ -181,3 +194,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ep_rank=self.ep_rank, output=output, ) + + +def flashinfer_cutlass_moe_fp4( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + + fused_experts = mk.FusedMoEModularKernel( + FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, + a1_gscale=a1_gscale), + FlashInferExperts( + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=hidden_states.dtype, + quant_dtype="nvfp4", + )) + + return fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a9fb273245b74ec44b34b331bc609baeffffd02c..5d0ce94acef73c6232b1be33d4297bd067b3ba3d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -52,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1202,8 +1202,23 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ + current_platform.is_cuda() and \ + num_expert_group <= 32 and topk <= 32 and \ + e_score_correction_bias is not None: + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor) assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") @@ -1249,9 +1264,39 @@ def grouped_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + topk_values, topk_indices = ops.grouped_topk( + scores, scores_with_bias.to(scores.dtype), num_expert_group, + topk_group, topk, renormalize, routed_scaling_factor) + return topk_values.to(torch.float32), topk_indices.to(torch.int32) + + def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False, @@ -1659,9 +1704,8 @@ def fused_experts(hidden_states: torch.Tensor, # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - if (allow_deep_gemm and use_fp8_w8a8 - and (is_blackwell_deep_gemm_e8m0_used() - or _valid_deep_gemm(hidden_states, w1, w2))): + if (allow_deep_gemm and use_fp8_w8a8 and + (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") @@ -2102,8 +2146,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0e897cc1640add4d5f3a4b5d10710c6d0744d2e9..546e60556acf6ce98440e0fd38cb41e50017ecd6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -45,6 +45,7 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, round_up) from vllm import _custom_ops as ops +from lightop import op if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -207,7 +208,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): else: return None - def init_prepare_finalize(self): + # Note: init_prepare_finalize should only be called by + # prepare_communication_buffer_for_model. + def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None prepare_finalize = self.maybe_make_prepare_finalize(self.moe) @@ -218,7 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): assert self.fused_experts is None, \ f"Attempt to override experts for {id(self)}!" self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe) + experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, @@ -228,6 +231,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): self, prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation @@ -250,6 +254,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -280,6 +285,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): prepare_finalize: FusedMoEPrepareAndFinalize, # TODO(bnell): Remove. Every layer should have an moe config object. moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): @@ -382,12 +388,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_prepack=True, ) elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - dtype = layer.w13_weight.dtype + from vllm.model_executor.layers.utils import ( + check_cpu_sgl_kernel) + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() if (envs.VLLM_CPU_SGL_KERNEL - and torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16): + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): packed_w13_weight = torch.ops._C.convert_weight_packed( layer.w13_weight) assert packed_w13_weight.size() == layer.w13_weight.size() @@ -401,7 +412,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: - raise NotImplementedError("CPU MOE only supports x86 arch.") + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) def apply( self, @@ -417,6 +428,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -447,6 +459,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -473,6 +486,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -495,6 +509,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, enable_eplb=enable_eplb, @@ -564,6 +579,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -593,6 +609,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map, custom_routing_function, scoring_func, + routed_scaling_factor, e_score_correction_bias, apply_router_weight_on_input, activation, @@ -612,6 +629,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -650,6 +668,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -670,6 +689,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): raise NotImplementedError( "Expert score correction bias is not supported for TPU.") assert activation == "silu", f"{activation} is not supported for TPU." + assert routed_scaling_factor == 1.0, \ + f"routed_scaling_factor {routed_scaling_factor} is not supported " \ + f"for TPU." if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -688,6 +710,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): forward_native = forward_tpu elif current_platform.is_cpu(): forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu else: forward_native = forward_cuda @@ -735,6 +759,26 @@ def determine_expert_map( return (local_num_experts, expert_map) +def get_compressed_expert_map(expert_map: torch.Tensor) -> str: + """ + Compresses the expert map by removing any -1 entries. + + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. + + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ + global_indices = torch.where(expert_map != -1)[0] + local_indices = expert_map[global_indices] + return ", ".join( + f"{local_index.item()}->{global_index.item()}" + for local_index, global_index in zip(local_indices, global_indices)) + + @CustomOp.register("fused_moe") class FusedMoE(CustomOp): """FusedMoE layer for MoE models. @@ -777,6 +821,7 @@ class FusedMoE(CustomOp): prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -836,6 +881,12 @@ class FusedMoE(CustomOp): ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) @@ -854,6 +905,7 @@ class FusedMoE(CustomOp): self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation @@ -951,7 +1003,7 @@ class FusedMoE(CustomOp): self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_parallel_config.use_flashinfer_cutlass_kernels): + or self.moe_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, @@ -1005,7 +1057,7 @@ class FusedMoE(CustomOp): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + return self.moe_config.use_flashinfer_cutlass_kernels def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1459,6 +1511,7 @@ class FusedMoE(CustomOp): num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, @@ -1498,15 +1551,26 @@ class FusedMoE(CustomOp): assert topk_group is not None assert num_expert_group is not None if use_fused_gate: - topk_weights, topk_ids = ops.moe_fused_gate( - router_logits, - e_score_correction_bias, - num_expert_group, - topk_group, - top_k, - routed_scaling_factor=routed_scaling_factor, - n_share_experts_fusion=0, - ) + if envs.VLLM_USE_LIGHT_OP: + topk_weights, topk_ids = op.moe_fused_gate( + router_logits, + e_score_correction_bias, + num_expert_group, + topk_group, + top_k, + 0, + routed_scaling_factor, + ) + else: + topk_weights, topk_ids = ops.moe_fused_gate( + router_logits, + e_score_correction_bias, + num_expert_group, + topk_group, + top_k, + routed_scaling_factor=routed_scaling_factor, + n_share_experts_fusion=0, + ) else: topk_weights, topk_ids = grouped_topk( hidden_states=hidden_states, @@ -1516,7 +1580,9 @@ class FusedMoE(CustomOp): num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) elif custom_routing_function is None: @@ -1683,6 +1749,7 @@ class FusedMoE(CustomOp): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, enable_eplb=self.enable_eplb, @@ -1723,7 +1790,7 @@ class FusedMoE(CustomOp): # only when data parallelism (DP) is enabled. use_flashinfer_cutlass_kernels = ( self.dp_size > 1 - and self.moe_parallel_config.use_flashinfer_cutlass_kernels) + and self.moe_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or use_flashinfer_cutlass_kernels): @@ -1732,7 +1799,7 @@ class FusedMoE(CustomOp): do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) + and not self.moe_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1751,6 +1818,7 @@ class FusedMoE(CustomOp): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index f427f3feccf0b5151a7714c0162cf82a9613d627..a0159f65a18b63e9866110c548e7ee67f6f00545 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -8,6 +8,9 @@ from vllm import _custom_ops as ops from vllm.triton_utils import triton from vllm.utils import round_up +import vllm.envs as envs +from lightop import op + def moe_align_block_size( topk_ids: torch.Tensor, @@ -92,8 +95,12 @@ def moe_align_block_size( dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + if envs.VLLM_USE_LIGHT_OP: + op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + else: + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) if expert_map is not None: expert_ids = expert_map[expert_ids] diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index d35bd0098b3ca9cc53c116430085fbd6870fa0a6..582ae3e12c2899d50560ee87ed8e9604e84709b9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torch_xla.experimental.custom_kernel # noqa: F401 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: @@ -41,6 +40,7 @@ def fused_moe( gating_output: [*, num_experts] """ assert expert_map is None, "expert_map is not supported for pallas MoE." + import torch_xla.experimental.custom_kernel # noqa: F401 orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index d9059f50b445939900e9c9ef6c2044d38d4f2387..16a155e7184788d43c3573842daf3544f1c228e9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -82,7 +82,8 @@ def moe_permute( n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1 + fill_invalid_expert: int = -1, + permuted_hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -95,14 +96,17 @@ def moe_permute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices + - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. + If None, the output tensor will be created in this function. Returns: - permuted_hidden_states (torch.Tensor): permuted activation. - - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states + - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states + if original scale not per-tensor scaling - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. @@ -122,11 +126,16 @@ def moe_permute( 1) // align_block_size * align_block_size if n_local_expert == -1: n_local_expert = n_expert - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if permuted_hidden_states is None: + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( + f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" + f" but got {permuted_hidden_states.size()}") + token_expert_indices = torch.arange(0, n_token * topk, dtype=torch.int32, @@ -153,7 +162,8 @@ def moe_permute( align_block_size, permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, permuted_idx, m_indices) - if a1q_scale is not None: + + if a1q_scale is not None and a1q_scale.dim() > 1: a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] return (permuted_hidden_states, a1q_scale, expert_first_token_offset, @@ -185,6 +195,7 @@ def moe_unpermute( n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, inv_permuted_idx, expert_first_token_offset, topk, out) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 931fd6cedb12cce56a04ce7fccf02813cca07c03..4b08e1ab9cc8cbfabc9b633b3b14ec7c442a5512 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -268,6 +268,7 @@ def rocm_aiter_grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] @@ -280,7 +281,7 @@ def rocm_aiter_grouped_topk( if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( gating_output, - e_score_correction_bias, + e_score_correction_bias.to(gating_output.dtype), topk_weights, topk_ids, num_expert_group, @@ -299,6 +300,8 @@ def rocm_aiter_grouped_topk( scoring_func, ) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights, topk_ids @@ -410,15 +413,15 @@ def shuffle_weights( *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> tuple[torch.Tensor, ...]: """ - Applies shuffle_weight function from AITER to each + Applies shuffle_weight function from AITER to each input tensor and returns them. - + Rearranges (shuffles) the input tensor/s into a specified block layout for optimized computation. Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the + layout: A pair of integers specifying the block sizes used to divide the tensors during shuffling. Default is (16, 16). diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 486ca881df48c8298c46973b72f9658f497b0ba7..6cd81d97f029866c30472861c8dccb47be310fdc 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used() + if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( @@ -143,7 +143,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_e8m0_used())) + or is_deep_gemm_e8m0_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..14dfce4b0e3aabd9de56ec51be81787f707dd798 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) +from vllm.utils import next_power_of_2 + + +class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + moe: FusedMoEConfig, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + w13_bias, + w2_bias, + max_capture_size, + ): + super().__init__(moe.quant_config) + self.moe = moe + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.w13_bias = w13_bias + self.w2_bias = w2_bias + self.max_capture_size = max_capture_size + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # The workspaces for this implementation are managed by flashinfer. + # TODO(varun) : workspace1 is could be used as the output tensor. This + # is error-prone. Allow the `workspace_shapes` to return None workspaces + workspace1 = (M, K) + workspace2 = (0, 0) + output = (M, K) + return (workspace1, workspace2, output, a.dtype) + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, + local_num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # 1.0 means perfect expert distribution. + # > 1.0 means some experts have more tokens than the perfect + # distribution. + # < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert assuming perfect + # distribution. + num_tokens_per_expert = (num_tokens * top_k) // local_num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the + # kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + topk = topk_ids.size(-1) + local_num_experts = w1.size(0) + intermediate_size = w2.size(1) + local_expert_offset = self.moe.ep_rank * local_num_experts + + x_quant = hidden_states + x_scale = a1q_scale + if x_scale is not None: + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x_quant.shape[:-1], -1) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16).view(torch.int16) + + assert w1_scale is not None + assert w2_scale is not None + kwargs = { + "topk_ids": + packed_tensor, + "routing_bias": + None, + "hidden_states": + x_quant, + "hidden_states_scale": + x_scale, + "gemm1_weights": + w1, + "gemm1_weights_scale": + w1_scale, + "gemm1_bias": + self.w13_bias, + "gemm1_alpha": + self.gemm1_alpha, + "gemm1_beta": + self.gemm1_beta, + "gemm1_clamp_limit": + self.gemm1_clamp_limit, + "gemm2_weights": + w2, + "gemm2_weights_scale": + w2_scale, + "gemm2_bias": + self.w2_bias, + "output1_scale_scalar": + None, + "output1_scale_gate_scalar": + None, + "output2_scale_scalar": + None, + "num_experts": + global_num_experts, + "top_k": + topk, + "n_group": + None, + "topk_group": + None, + "intermediate_size": + intermediate_size, + "local_expert_offset": + local_expert_offset, + "local_num_experts": + local_num_experts, + "routed_scaling_factor": + None, + "tile_tokens_dim": + self._get_tile_tokens_dim(x_quant, topk, local_num_experts), + "routing_method_type": + 1, + "do_finalize": + True, + "output": + output, + "tune_max_num_tokens": + self.max_capture_size, + } + + from flashinfer import trtllm_fp4_block_scale_routed_moe + trtllm_fp4_block_scale_routed_moe(**kwargs) + return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 95431d7acd39274cea6ca79ab1faa08e2c8cf64f..97abc9fcc06e7c65dea2c9aabb44008cf38cdc15 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -16,6 +16,8 @@ except Exception: from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( quant_dequant_mxfp4) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + mxfp8_quantize) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv @@ -181,6 +183,18 @@ def _mxfp4_quantize( return A, None +def _mxfp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert A_scale is None + assert not per_act_token_quant + assert block_shape is None + return mxfp8_quantize(A) + + def moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -199,6 +213,8 @@ def moe_kernel_quantize_input( is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp8": + return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 8ffc700ca5cde83e2a365ac80c52aa28eb818365..0b87acc8512080bd05a32ec892e6c0bc838b4237 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from einops import rearrange @@ -453,7 +455,14 @@ class _attention(torch.autograd.Function): lightning_attention_ = _attention.apply -def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): +def lightning_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ed: torch.Tensor, + block_size: int = 256, + kv_history: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: """ Apply lightning attention algorithm to compute attention efficiently. diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 70c9f2ffbda02c69e819fa794ba8bb20245655b1..76e7354abe8144fb4d2fd9c4b903c59a52156033 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -38,6 +38,7 @@ logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", @@ -45,7 +46,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", @@ -56,6 +56,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "HQQMarlinMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "PetitNvFp4LinearMethod", "BlockInt8LinearMethod", ] @@ -216,12 +217,12 @@ class UnquantizedLinearMethod(LinearMethodBase): set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # special postprocessing for CPU SGL if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel N, K = layer.weight.size() dtype = layer.weight.dtype - if (torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16 and N % 32 == 0 - and K % 32 == 0): + if check_cpu_sgl_kernel(N, K, dtype): packed_weight = torch.ops._C.convert_weight_packed( layer.weight) assert packed_weight.size() == layer.weight.size() @@ -233,7 +234,8 @@ class UnquantizedLinearMethod(LinearMethodBase): else: logger.warning( "CPU SGL kernels require Intel AMX support," - " bfloat16 weight, IC and OC are divisible by 32.") + " bf16/fp16/int8 weight, IC and OC are divisible by " + "32 and 16.") layer.use_cpu_sgl = False def apply(self, @@ -264,10 +266,10 @@ class LinearBase(CustomOp): Args: input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. - bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. """ @@ -413,13 +415,14 @@ class MergedReplicatedLinear(ReplicatedLinear): Args: input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -472,7 +475,7 @@ class MergedReplicatedLinear(ReplicatedLinear): shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] - param[shard_offset:shard_offset + shard_size] = loaded_weight + param.data[shard_offset:shard_offset + shard_size] = loaded_weight @CustomOp.register("column_parallel_linear") @@ -1452,7 +1455,7 @@ class RowParallelLinear(LinearBase): return output, output_bias def extra_repr(self) -> str: - s = f"input_features={self.input_size_per_partition}" + s = f"in_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" @@ -1543,7 +1546,7 @@ class QKVCrossParallelLinear(LinearBase): self.bias = torch.nn.Parameter() set_weight_attrs(self.bias, { "output_dim": 0, - "weight_loader": self.weight_loader, + "weight_loader": self.weight_loader_v1, }) else: self.bias = None @@ -1653,6 +1656,18 @@ class QKVCrossParallelLinear(LinearBase): k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v + def weight_loader_v1(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # just like all other parameters, does not yet + # support loading bias with weight_loader_v2 + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f7715928d952be45bd49a7e362a36fb..a524e13405807c55a281a89782590f3e693a099d 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -class MambaBase(ABC): +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class MambaBase(AttentionLayerBase): """ Base class for Mamba-like layers which support the v1 engine. Inherit from this class if you implement a custom layer. @@ -32,3 +38,8 @@ class MambaBase(ABC): @abstractmethod def mamba_type(self) -> str: pass + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this Mamba layer.""" + pass diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d93cef1a27ad4a1f3265e2a7d7e1f8af9547912a --- /dev/null +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -0,0 +1,442 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +from typing import TYPE_CHECKING + +import torch +import torch.distributed +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from vllm import envs +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.lightning_attn import ( + lightning_attention, linear_decode_forward_triton) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +import torch +import torch.distributed + +from vllm.model_executor.models.minimax_cache import MinimaxCacheParams + + +class MiniMaxText01RMSNormTP(CustomOp): + name = "MiniMaxText01RMSNormTP" + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.tp_world = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.weight = nn.Parameter(torch.ones(int(hidden_size / + self.tp_world))) + + self.weight.weight_loader = self.weight_loader + self.variance_epsilon = eps + return + + @staticmethod + def weight_loader( + param: nn.Parameter, + loaded_weight: torch.Tensor, + ) -> None: + tp_world = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + shard_size = loaded_weight.shape[0] // tp_world + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard]) + return + + def _forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) + if self.tp_world > 1: + variance = tensor_model_parallel_all_reduce( + variance) / self.tp_world + x = x * torch.rsqrt(variance + self.variance_epsilon) + + weight = self.weight + if x.size(-1) != self.weight.size(0): + if self.weight.size(0) < x.size(-1): + repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) + full_weight = self.weight.repeat(repeat_count) + weight = full_weight[:x.size(-1)] + else: + weight = self.weight[:x.size(-1)] + + x = x.to(orig_dtype) * weight + return x + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert residual is None, "RMSNorm does not support residual connection." + return self._forward(x) + + +class MiniMaxText01LinearKernel: + + @staticmethod + def jit_linear_forward_prefix(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: Optional[int] = None, + **kwargs) -> torch.Tensor: + + slope_rate = slope_rate.to(torch.float32) + should_pad_dim = q.dim() == 3 + if should_pad_dim: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + b, h, n, d = q.shape + e = d + kv_history = kv_caches.reshape(1, h, d, e).contiguous() + output, kv_history = lightning_attention(q, + k, + v, + slope_rate, + block_size=block_size, + kv_history=kv_history) + kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) + assert output.shape[0] == 1, "batch size must be 1" + return rearrange(output.squeeze(0), "h n d -> n (h d)") + + +class MiniMaxText01LinearAttention(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.linear_attn import ( + LinearAttentionBackend) + return LinearAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=self.num_heads, + tp_size=self.tp_size, + head_dim=self.head_dim) + + def __init__( + self, + hidden_size: int, + hidden_inner_size: int, + num_heads: int, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size * 3, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.output_gate = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_gate", + ) + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, + ) + + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.num_heads) + if num_hidden_layer <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * (1 - layer_idx / + (num_hidden_layer - 1) + 1e-5) + self.tp_slope = self.slope_rate[self.tp_rank * + self.tp_heads:(self.tp_rank + 1) * + self.tp_heads].contiguous() + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + @staticmethod + def weight_direct_load(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.tensor(get_slopes(n_attention_heads), + dtype=torch.float32).reshape( + n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break + # prefills are packed at end of batch in V1 + offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 + _start = attn_metadata.query_start_loc[offset + _prefill_idx] + _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] + slot_id = state_indices_tensor[offset + _prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slice_layer_cache = kv_cache[slot_id, ...] + + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden_decode = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + if envs.VLLM_USE_V1: + hidden.insert(0, hidden_decode) + else: + hidden.append(hidden_decode) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + if not envs.VLLM_USE_V1: + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + num_prefills = getattr(attn_metadata, "num_prefills", 0) + slot_id = state_indices_tensor[num_prefills:] + else: + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decodes] + hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, + slot_id, 32) + return hidden + + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> None: + if not envs.VLLM_USE_V1: + self._forward(hidden_states, output, positions, kv_caches) + else: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[MinimaxCacheParams]) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1 and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + if envs.VLLM_USE_V1: + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor + + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, + "num_decode_tokens", 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx + + 1] + query_len = q_end - q_start + context_len = attn_metadata.seq_lens[ + num_decode_tokens + prefill_idx] - query_len + if context_len == 0: + block_to_clear = state_indices_tensor[ + num_decode_tokens + prefill_idx] + kv_cache[block_to_clear, ...] = 0 + else: + assert kv_caches is not None + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if attn_metadata is None: + hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), + device=q.device, + dtype=q.dtype) + else: + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) + hidden = F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + + output[:num_actual_tokens], _ = self.out_proj(hidden) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, + output=output, + positions=positions, + kv_caches=None) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 3c7322260df4343a5b1d47a070e16a2cd475cb6a..e704bfd451bce67e0ab15f1bf45f0826b713e3c1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -27,6 +30,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -183,22 +188,26 @@ class MambaMixer(MambaBase, CustomOp): def forward(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): if not envs.VLLM_USE_V1: - return CustomOp.forward(self, hidden_states, mamba_cache_params) + CustomOp.forward(self, hidden_states, output, mamba_cache_params) else: - return self.forward_cuda( + torch.ops.vllm.mamba_mixer( hidden_states, - mamba_cache_params, + output, + self.prefix, ) def forward_native(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): pass def forward_cuda(self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): """ Run the Mamba-1 SSM pipeline. @@ -237,6 +246,7 @@ class MambaMixer(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes else: assert isinstance(attn_metadata, AttentionMetadata) assert mamba_cache_params is not None @@ -248,6 +258,7 @@ class MambaMixer(MambaBase, CustomOp): has_initial_states = None if context_lens_tensor is not None: has_initial_states = context_lens_tensor > 0 + num_padded_decodes = attn_metadata.num_decode_tokens # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -267,6 +278,7 @@ class MambaMixer(MambaBase, CustomOp): num_decodes = attn_metadata.num_decode_tokens # token count (=request) has_prefill = num_prefill_tokens > 0 has_decode = num_decode_tokens > 0 + num_actual_tokens = num_prefill_tokens + num_decode_tokens prefill_decode_split = split_batch_to_prefill_and_decode( hidden_states_BC, @@ -278,6 +290,7 @@ class MambaMixer(MambaBase, CustomOp): num_decode_tokens, num_prefills, num_decodes, + num_padded_decodes, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d @@ -371,7 +384,7 @@ class MambaMixer(MambaBase, CustomOp): else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] - return out + output[:num_actual_tokens] = out def get_state_dtype(self) -> tuple[torch.dtype]: assert self.model_config is not None @@ -394,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba1" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba1_attn import ( + Mamba1AttentionBackend) + return Mamba1AttentionBackend + def _time_proj_bias(self) -> Optional[torch.Tensor]: if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() @@ -421,18 +439,27 @@ def split_batch_to_prefill_and_decode( num_decode_tokens: int, num_prefills: int, num_decodes: int, + num_padded_decodes: int, ) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes + if envs.VLLM_USE_V1: # In v1, decode tokens come first, then prefill tokens. hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1) - gate_d, gate_p = torch.split(gate, - [num_decode_tokens, num_prefill_tokens], + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], dim=-1) + + # num_padded_decodes accounts for CUDA graph padding when applicable state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, [num_decodes, num_prefills], dim=0) + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_decodes if num_prefills > 0 else None) + num_padded_decodes if num_prefills > 0 else None) has_initial_states_p = has_initial_states[-num_prefills:] if ( has_initial_states is not None and num_prefills > 0) else None else: @@ -459,3 +486,32 @@ def split_batch_to_prefill_and_decode( query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, ) + + +def mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None) + + +def mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mamba_mixer", + op_func=mamba_mixer, + mutates_args=["output"], + fake_impl=mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 743e520ec8ee1768cb261b96779183e8a756fd60..bb3fdd38dbef3be9cc578dea93239fabb37a1bbe 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba2" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + def mamba_mixer2( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 66674d1a6f25157e30df50a9f27298e787234a61..280a9e45e662e26d00fc5fb1f8c6397785b1b7f6 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -54,6 +54,16 @@ class MambaStateDtypeCalculator: return (conv_state_dtype, temporal_state_dtype) + @classmethod + def short_conv_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + return (conv_state_dtype, ) + class MambaStateShapeCalculator: @@ -122,6 +132,20 @@ class MambaStateShapeCalculator: tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape + @classmethod + def short_conv_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + conv_dim = divide(intermediate_size, tp_world_size) + conv_state_shape = (conv_kernel - 1, conv_dim) + if not use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + return (conv_state_shape, ) + @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): """Compute the increase in group numbers to account for diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..335191a5c82c15295b31dd4a3bae9367e2a58fd0 --- /dev/null +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionMetadata) + + +@CustomOp.register("short_conv") +class ShortConv(MambaBase, CustomOp): + + def __init__(self, + config, + dim: int, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + self.kv_cache = [(torch.tensor([]), )] + + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + return + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + torch.ops.vllm.short_conv( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + forward_context = get_forward_context() + # ShortConvAttentionMetadata contains metadata necessary for the + # short_conv triton kernels to operate in continuous batching and in + # chunked prefill modes; they are computed at top-level model forward + # since they stay the same and reused for all mamba layers in the same + # iteration. + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, ShortConvAttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_decodes + num_prefill_tokens + + # NOTE: V1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + B_d, B_p = torch.split( + B[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(Bx_p, query_start_loc_p, + conv_metadata) + Bx = causal_conv1d_fn(Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=conv_metadata, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + y = C_p * Bx + conv_output_list.append(y) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + y = C_d * Bx + conv_output_list.insert(0, y) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + output[:num_actual_tokens], _ = self.out_proj(hidden_states) + + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.short_conv_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...]]: + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.conv_dim, + conv_kernel=self.L_cache, + ) + + @property + def mamba_type(self) -> str: + return "short_conv" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + return ShortConvAttentionBackend + + +def short_conv( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + conv_metadata=None) + + +def short_conv_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="short_conv", + op_func=short_conv, + mutates_args=["output"], + fake_impl=short_conv_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index e2162e5cbf956425d0d667305c5fa08c58782981..62b3ee1abaca841b8814909d7322cc26141c7be7 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union, cast import torch import torch.nn as nn @@ -13,16 +13,15 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.config import ModelConfig, PoolerConfig -from vllm.model_executor.pooling_metadata import ( # noqa: E501 - PoolingMetadata as V0PoolingMetadata) -from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.tasks import PoolingTask -from vllm.utils import resolve_obj_by_qualname -from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata +from vllm.utils import current_stream, resolve_obj_by_qualname +from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata + +logger = init_logger(__name__) -PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] PoolingFn = Callable[ [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], Union[torch.Tensor, list[torch.Tensor]]] @@ -126,36 +125,23 @@ def get_prompt_lens( hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states[0].device).prompt_lens + return pooling_metadata.prompt_lens def get_prompt_token_ids( pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) ] def get_pooling_params( pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [p for _, p in pooling_metadata.seq_groups] - else: - pooling_params = pooling_metadata.pooling_params + pooling_params = pooling_metadata.pooling_params return pooling_params @@ -172,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: def get_classification_activation_function(config: PretrainedConfig): + # Implement alignment with transformers ForSequenceClassificationLoss + # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 + problem_type = getattr(config, "problem_type", "") + if problem_type == "regression": + return PoolerIdentity() + if problem_type == "single_label_classification": + return PoolerClassify() + if problem_type == "multi_label_classification": + return PoolerMultiLabelClassify() return PoolerClassify() @@ -191,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): fn = resolve_obj_by_qualname(function_name)() return PoolerActivation.wraps(fn) - return PoolerScore() + return PoolerClassify() def build_output( all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + # Pooling models D2H & synchronize occurs here + if isinstance(all_data, list): + all_data = [d.to("cpu", non_blocking=True) for d in all_data] + else: + all_data = all_data.to("cpu", non_blocking=True) + current_stream().synchronize() + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] return PoolerOutput(outputs=all_outputs) @@ -222,40 +224,21 @@ class PoolingMethod(nn.Module, ABC): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return PoolingParamsUpdate() - @abstractmethod - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Note: - `prompt_len=None` means `prompt_len=len(hidden_states)`. - """ - raise NotImplementedError - @abstractmethod def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: raise NotImplementedError def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - - if isinstance(hidden_states, list): - return [ - self.forward_one(h, prompt_len) - for h, prompt_len in zip(hidden_states, prompt_lens) - ] - - return self.forward_all(hidden_states, prompt_lens) + pooling_cursor = pooling_metadata.pooling_cursor + return self.forward_all(hidden_states, pooling_cursor) class CLSPool(PoolingMethod): @@ -263,24 +246,15 @@ class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with CLS pooling" - - return hidden_states[0] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - first_token_flat_indices = torch.zeros_like(prompt_lens) - first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] - return hidden_states[first_token_flat_indices] + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with CLS pooling" + + return hidden_states[pooling_cursor.first_token_indices_gpu] class LastPool(PoolingMethod): @@ -288,20 +262,12 @@ class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return hidden_states[-1] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 - return hidden_states[last_token_flat_indices] + return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): @@ -309,22 +275,19 @@ class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with ALL pooling" - - return hidden_states - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - return list(hidden_states.split_with_sizes(prompt_lens.tolist())) + + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with ALL pooling" + + hidden_states_lst = list( + hidden_states.split( + pooling_cursor.num_scheduled_tokens_cpu.tolist())) + return [hidden_states_lst[i] for i in pooling_cursor.index] class MeanPool(PoolingMethod): @@ -332,31 +295,25 @@ class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( + def forward_all( self, hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + pooling_cursor: PoolingCursor, + ) -> Union[list[torch.Tensor], torch.Tensor]: + + assert not pooling_cursor.is_partial_prefill(), \ "partial prefill not supported with MEAN pooling" - return hidden_states.mean(dim=0, dtype=torch.float32) + prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device, + non_blocking=True) - def forward_all( - self, - hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, - ) -> Union[list[torch.Tensor], torch.Tensor]: # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) - start_indices = torch.cat([ - torch.tensor([0], device=hidden_states.device), - torch.cumsum(prompt_lens[:-1], dim=0) - ]) - end_indices = torch.cumsum(prompt_lens, dim=0) - return (cumsum[end_indices - 1] - cumsum[start_indices] + + start_indices = pooling_cursor.first_token_indices_gpu + end_indices = pooling_cursor.last_token_indices_gpu + return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) @@ -409,24 +366,37 @@ class PoolerNormalize(PoolerActivation): return x.to(pooled_data.dtype) -class PoolerClassify(PoolerActivation): +class PoolerMultiLabelClassify(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] - if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) - return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) +class PoolerClassify(PoolerActivation): -class PoolerScore(PoolerActivation): + def __init__(self, *, static_num_labels: bool = True) -> None: + super().__init__() + + if static_num_labels: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.num_labels = getattr(vllm_config.model_config.hf_config, + "num_labels", 0) + if self.num_labels == 0: + logger.warning("num_labels should be > 0 for classification" + "models, falling back to softmax. " + "Please check if the configuration is correct.") + else: + self.num_labels = None def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] + num_labels = (self.num_labels if self.num_labels is not None else + pooled_data.shape[-1]) + if num_labels < 2: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) - return pooled_data + return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) class LambdaPoolerActivation(PoolerActivation): @@ -457,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) + # Load ST projector if available + from vllm.config import get_current_vllm_config + from vllm.model_executor.models.adapters import _load_st_projector + + vllm_config = get_current_vllm_config() + self.projector = _load_st_projector( + vllm_config.model_config) if vllm_config else None + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + projector = cast(nn.Module, self.projector) + + def _proj(x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + y = projector(x.to(torch.float32)) + return y.to(orig_dtype) + + pooled_data = _proj(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] + pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation @@ -491,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead): for vecs, f in zip(pooled_data, flags) ] + # pooled_data shape: [batchsize, embedding_dimension] return pooled_data class RewardPoolerHead(PoolerHead): def __init__(self) -> None: - super().__init__(activation=PoolerClassify()) + super().__init__(activation=PoolerClassify(static_num_labels=False)) def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -651,15 +646,13 @@ class ClassifierPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_size] if self.classifier is not None: - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_data = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_data = self.classifier(torch.stack(pooled_data)) - else: - pooled_data = [self.classifier(data) for data in pooled_data] + pooled_data = self.classifier(pooled_data) + # pooled_data shape: [batchsize, num_labels] pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -672,6 +665,7 @@ class ClassifierPooler(Pooler): for vecs, f in zip(pooled_data, flags) ] + # scores shape: [batchsize, num_labels] return build_output(scores) @@ -702,12 +696,6 @@ class DispatchPooler(Pooler): ) -> PoolerOutput: poolers_by_task = self.poolers_by_task - if isinstance(hidden_states, list): - hidden_states_lst = hidden_states - else: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - hidden_states_lst = list(hidden_states.split(prompt_lens.tolist())) - outputs = list[PoolingSequenceGroupOutput]() offset = 0 for task, group in groupby(get_tasks(pooling_metadata)): @@ -718,7 +706,7 @@ class DispatchPooler(Pooler): num_items = len(list(group)) group_output: PoolerOutput = pooler( - hidden_states_lst[offset:offset + num_items], + hidden_states, pooling_metadata[offset:offset + num_items], ) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f0531234e008ac9568ab81711cfe2b587c03161e..db5e4cd16c0d844b86285e018a9a74d9c05a456d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -15,7 +15,6 @@ QuantizationMethods = Literal[ "fbgemm_fp8", "modelopt", "modelopt_fp4", - "marlin", "bitblas", "gguf", "gptq_marlin_24", @@ -25,7 +24,6 @@ QuantizationMethods = Literal[ "gptq", "compressed-tensors", "bitsandbytes", - "qqq", "hqq", "experts_int8", "neuron_quant", @@ -37,6 +35,7 @@ QuantizationMethods = Literal[ "rtn", "inc", "mxfp4", + "petit_nvfp4", "blockwise_int8", "slimquant_w4a8", "slimquant_w4a8_marlin", @@ -109,13 +108,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .hqq_marlin import HQQMarlinConfig from .inc import INCConfig from .ipex_quant import IPEXConfig - from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig + from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config - from .qqq import QQQConfig from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig @@ -131,7 +129,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, - "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, @@ -142,7 +139,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "ptpc_fp8": PTPCFp8Config, - "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, @@ -154,6 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "rtn": RTNConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "petit_nvfp4": PetitNvFp4Config, "blockwise_int8": BlockInt8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config, "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 946f7174bb298d5bdf264f3720dd568e3c5cfb0c..51451f75e560fe7ba30d49bb12f8a5b4d7e384b4 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -518,6 +518,7 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -547,6 +548,7 @@ class AWQMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, routed_scaling_factor=routed_scaling_factor, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index b7897a43793c73c22e8e0bdd7b9cf2c3782aedc1..9713757df9b077a9db3c776d13aa3ba7cfaaded5 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) if self.quant_config.load_in_8bit: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b7628919666608ac0ab13db6886f656f7e753be2..977a73f9965d075df8104482543741a8c69ae1b3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat, from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.transform import TransformConfig from pydantic import BaseModel import vllm.envs as envs @@ -29,10 +30,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod, get_linear_transform_schemes) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -67,6 +70,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: list[str], kv_cache_scheme: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None, + transform_config: Optional[TransformConfig] = None, ): super().__init__() self.ignore = ignore @@ -78,6 +82,12 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_ignore_list = sparsity_ignore_list self.config = config + if transform_config is not None: + self.transform_config = TransformConfig.model_validate( + transform_config) + else: + self.transform_config = None + def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -110,18 +120,26 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import - # Check if the layer is skipped for quantization. - # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): - return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod() if isinstance(layer, LinearBase): - scheme = self.get_scheme(layer=layer, layer_name=prefix) - if scheme is None: - return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod() - layer.scheme = scheme - return CompressedTensorsLinearMethod(self) + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + input_tfms, output_tfms = get_linear_transform_schemes( + layer, prefix, self.transform_config, + self.packed_modules_mapping) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + quant_method = CompressedTensorsLinearMethod(self) + + # choose transform method + if any((input_tfms, output_tfms)): + return CompressedTensorsLinearTransformMethod.from_schemes( + quant_method, input_tfms, output_tfms) + + else: + return quant_method if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): @@ -136,6 +154,7 @@ class CompressedTensorsConfig(QuantizationConfig): config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( config=config) + transform_config = config.get("transform_config") return cls( target_scheme_map=target_scheme_map, @@ -144,6 +163,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, + transform_config=transform_config, ) @classmethod @@ -207,8 +227,10 @@ class CompressedTensorsConfig(QuantizationConfig): format ) if format is not None else is_activation_quantization_format( quant_format) - if act_quant_format: - input_activations = quant_config.get("input_activations") + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run for cases where @@ -359,6 +381,28 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w4a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. + return (is_weight_4_bits and is_activation_8_bits and is_token + and is_symmetric and is_dynamic) + + def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w4a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) @@ -408,19 +452,30 @@ class CompressedTensorsConfig(QuantizationConfig): weight_quant: BaseModel, input_quant: BaseModel, format: Optional[str] = None) -> "CompressedTensorsScheme": + + # use the per-layer format if defined, otherwise, use global format + format = format if format is not None else self.quant_format + # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value + if (format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value + if (format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, @@ -429,10 +484,7 @@ class CompressedTensorsConfig(QuantizationConfig): group_size=weight_quant.group_size, actorder=weight_quant.actorder) - act_quant_format = is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - self.quant_format) + act_quant_format = is_activation_quantization_format(format) if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): if cutlass_fp4_supported( @@ -512,9 +564,11 @@ class CompressedTensorsConfig(QuantizationConfig): # Find the "target" in the compressed-tensors config # that our layer conforms to. - # TODO (@robertgshaw): add compressed-tensors as dep - # so we do not have to re-write these functions - # need to make accelerate optional in ct to do this + # TODO (@kylesayrs): support ignore module names with ct matching utils + if should_ignore_layer(layer_name, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return None # Will be empty for models with only sparsity weight_quant = input_quant = None @@ -531,7 +585,7 @@ class CompressedTensorsConfig(QuantizationConfig): format = scheme_dict.get("format") # Find the sparsity scheme of the layer - # assume that fused layers inerhit first component's sparsity scheme + # assume that fused layers inherit first component's sparsity scheme sparsity_targets = (self.sparsity_scheme_map.keys() - set(self.sparsity_ignore_list)) sparsity_scheme: Optional[SparsityCompressionConfig] = None @@ -720,7 +774,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase): layer input. See LinearMethodBase for param details """ - scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 171f3742d8193802cd082ed0885857cc448dd05e..75864b03afea255954e06417c343b9c93e1ec07a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, @@ -66,12 +68,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module, + layer: torch.nn.Module ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. - weight_quant = quant_config.target_scheme_map["Linear"].get("weights") - input_quant = quant_config.target_scheme_map["Linear"].get( + # Check if a using "Linear" to select schemes + if "Linear" in quant_config.target_scheme_map: + matched_target = "Linear" + else: + # May have instead defined the linear layers in the fused model + + fused_layers = [ + "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*" + ] + current_scheme = None + for fused_layer in fused_layers: + # Check if one of the fused layers are defined in quant_config + matched_target = find_matched_target( + layer_name=fused_layer, + module=layer, + targets=quant_config.target_scheme_map.keys(), + fused_mapping=quant_config.packed_modules_mapping) + + # Only valid if down_proj, gate_proj, and up_proj + # are mapped to the same quant scheme in the quant_config + if current_scheme is None: + current_scheme = quant_config.target_scheme_map.get( + matched_target) + else: + assert current_scheme == quant_config.target_scheme_map.get( + matched_target) + + weight_quant = quant_config.target_scheme_map[matched_target].get( + "weights") + input_quant = quant_config.target_scheme_map[matched_target].get( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): @@ -247,13 +277,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): return # swizzle weight scales - layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w13_weight_scale), - requires_grad=False) + requires_grad=False) - layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w2_weight_scale), - requires_grad=False) + requires_grad=False) # w13 w13_input_global_scale = layer.w13_input_global_scale.max( @@ -293,6 +323,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return the appropriate GEMM experts implementation.""" experts = select_nvfp4_gemm_impl( @@ -320,6 +351,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -345,6 +377,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) @@ -384,8 +417,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + elif self.allow_flashinfer: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -401,8 +461,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -663,11 +723,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): pass + if self.use_cutlass: + device = layer.w13_weight.device + # ab_strides1 and c_strides2 are the same + self.ab_strides1_c_strides2 = torch.full( + (layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full( + (layer.local_num_experts, ), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full( + (layer.local_num_experts, ), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, - ) -> FusedMoEPermuteExpertsUnpermute: + self, prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute: # cutlass path if self.use_cutlass: from vllm.model_executor.layers.fused_moe import ( @@ -687,6 +765,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) @@ -694,6 +776,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, ) self.disable_expert_map = (num_dispatchers > 1 @@ -746,6 +832,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -773,9 +860,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - use_fused_gate=use_fused_gate, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + use_fused_gate=use_fused_gate, ) # cutlass path @@ -821,6 +908,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map=None if self.disable_expert_map else expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) @@ -1015,6 +1106,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1046,9 +1138,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - use_fused_gate=use_fused_gate, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + use_fused_gate=use_fused_gate) return fused_experts( hidden_states=x, @@ -1325,6 +1417,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1354,6 +1447,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) @@ -1557,6 +1651,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1583,6 +1678,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba7b966f3ee8c4540d63dc3967aa790d..cac65cca5093fddef6df7cf189dc3a42407dfd27 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) @@ -21,5 +22,6 @@ __all__ = [ "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 63bfe565b121182cb8656da4c86e856d5f136fb2..dedd681f15dedd56a5c9b2752b4945e341a8aeba 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 run_nvfp4_emulations) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + swizzle_blockscale) from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): weight_loader=weight_loader) layer.register_parameter("input_global_scale", input_global_scale) - def swizzle_blockscale(self, scale: torch.tensor): - assert (scale.dtype == torch.float8_e4m3fn) - # Pad and blockwise interleave weight_scale - scale_ndim = scale.ndim - if scale.ndim == 2: - scale = scale.unsqueeze(0) - assert scale.ndim == 3 - B, M, K = scale.shape - round_up_multiple = lambda x, m: (x + m - 1) // m * m - M_padded = round_up_multiple(M, 128) - K_padded = round_up_multiple(K, 4) - padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) - padded_scale[:B, :M, :K] = scale - batches, rows, cols = padded_scale.shape - assert rows % 128 == 0 - assert cols % 4 == 0 - padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, - cols // 4, 4) - swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) - swizzled_scale = swizzled_scale.contiguous().cuda() - return (swizzled_scale.reshape(M, K) - if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) - def process_weights_after_loading(self, layer) -> None: global_input_scale = layer.input_global_scale.max().to(torch.float32) @@ -133,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): torch.uint8), epilogue_tile_m).reshape( weight_scale.shape).view(torch.float8_e4m3fn)) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False) else: - swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False) @@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): x=x, input_global_scale=layer.input_global_scale, weight=layer.weight_packed, - weight_scale_swizzled=layer.weight_scale_swizzled, + weight_scale_swizzled=layer.weight_scale, weight_global_scale=layer.weight_global_scale) if bias is not None: out = out + bias @@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) mm_args = (x_fp4, layer.weight_packed, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, output_dtype) + layer.weight_scale, layer.alpha, output_dtype) if self.backend == "flashinfer-trtllm": out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") elif self.backend == "flashinfer-cutlass": diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9827058803e47dc82ebf40569edfb36deba880 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError("W4A8 kernels require group quantization " \ + "with group size 128") + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx, + out_type=params_dtype + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + # TODO(czhu): allocate the packed fp8 scales memory here? + # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=torch.float8_e4m3fn, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + # per-channel scales + weight_chan_scale = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("weight_chan_scale", weight_chan_scale) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc94b3c257e6a66391ceb785d03ff85c52d8233 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Generator +from itertools import accumulate +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import (TransformArgs, TransformConfig, + TransformLocation, TransformScheme) +from compressed_tensors.utils import is_match + +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + LinearMethodBase, + QKVCrossParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 + HadamardTransform) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) + + +class CompressedTensorsLinearTransformMethod(LinearMethodBase): + """ + Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds + input and output transforms to either side of the original apply method + """ + + @classmethod + def from_schemes( + cls, quant_method: LinearMethodBase, input_tfms: dict[int, + TransformTuple], + output_tfms: dict[int, TransformTuple] + ) -> "CompressedTensorsLinearTransformMethod": + assert input_tfms or output_tfms + + # TODO (@ksayers): implement QutlassLinearMethodNvFP4 + # hadacore and fwht can be selected by Transform module + + return cls(quant_method, input_tfms, output_tfms) + + def __init__(self, quant_method: LinearMethodBase, + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple]): + self.quant_method = quant_method + self.input_tfms = input_tfms + self.output_tfms = output_tfms + + self.input_transform: Optional[HadamardTransform] = None + self.output_transform: Optional[HadamardTransform] = None + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + # get weight loader for transforms + weight_loader: Callable = extra_weight_attrs.get( + "weight_loader") # type: ignore[assignment] + + # HACK: UnquantizedLinearMethod does not support weight loader v2, but + # transforms (specifically SharedWeightParameter) requires + # weight loader v2. Until UnquantizedLinearMethod supports v2, we must + # hack around this by getting weight loader v1 so ULM can load correctly + quant_method_name = self.quant_method.__class__.__name__ + if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: + if isinstance(layer, QKVCrossParallelLinear): + weight_loader_v1 = layer.weight_loader_v1 + else: + weight_loader_v1 = layer.weight_loader + extra_weight_attrs["weight_loader"] = weight_loader_v1 + + self.quant_method.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + input_size=input_size, + output_size=output_size, + params_dtype=params_dtype, + **extra_weight_attrs) + + # validate schemes + num_partitions = len(output_partition_sizes) + self._validate_tfm_schemes(num_partitions) + + # create submodules for weight loading + if len(self.input_tfms) > 0: + scheme_name = list(self.input_tfms.values())[0].scheme_name + location = list(self.input_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.input_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.input_transform = transform + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.output_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.output_transform = transform + + # compute partition ranges for slicing activations + starts = [0] + list(accumulate(output_partition_sizes))[:-1] + self.partition_ranges = list(zip(starts, output_partition_sizes)) + + def process_weights_after_loading(self, layer): + self.quant_method.process_weights_after_loading(layer) + + for submodule in layer.children(): + if isinstance(submodule, HadamardTransform): + submodule.process_weights_after_loading() + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.input_transform is not None: + x = self.input_transform(x) + + assert bias is None + x = self.quant_method.apply(layer, x, bias) + + # TODO (@ksayers): Write a triton kernel to do this in parallel + if self.output_transform is not None: + for part_id, (start, length) in enumerate(self.partition_ranges): + x[:, start:start + length] = self.output_transform( + x[:, start:start + length], part_id=part_id) + + return x + + def _validate_tfm_schemes(self, num_partitions: int): + if len(self.input_tfms) > 0: + if 0 not in self.input_tfms: + raise ValueError("Must have same input") + + for part_index in range(num_partitions): + if self.input_tfms[part_index] != self.input_tfms[0]: + raise ValueError("Must have same input") + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + + for tfm in self.output_tfms.values(): + if tfm.scheme_name != scheme_name: + raise ValueError("Must have same scheme name") + if tfm.args.location != location: + raise ValueError("Must have same location") + + return self.input_tfms, self.output_tfms + + +def get_linear_transform_schemes( + layer: torch.nn.Module, layer_name: str, + transform_config: Optional[TransformConfig], + packed_modules_mapping: dict[str, list[str]] +) -> tuple[dict[int, TransformTuple], dict[ + int, TransformTuple]]: # [input_transform, [output_transform, ...]] + # there can only be one transform input scheme per (fused) module + input_tfms = {} + output_tfms = {} + + partition_names = get_layer_partition_names(layer_name, + packed_modules_mapping) + + for scheme_name, scheme, args in get_schemes_args(transform_config): + for part_index, part_name in enumerate(partition_names): + if is_match(part_name, layer, args.targets, + args.ignore) and args.is_online(): + if args.location == TransformLocation.INPUT: + input_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + elif args.location == TransformLocation.OUTPUT: + output_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + else: + raise ValueError(f"Cannot apply `{args.location}` " + f"transform to `{layer_name}`") + + return (input_tfms, output_tfms) + + +def get_schemes_args( + transform_config: Optional[TransformConfig] +) -> Generator[tuple[str, TransformScheme, TransformArgs]]: + if transform_config is None: + return + + for scheme_name, scheme in transform_config.config_groups.items(): + for args in scheme.apply: + yield (scheme_name, scheme, args) + + +def get_layer_partition_names( + layer_name: str, packed_modules_mapping: dict[str, + list[str]]) -> list[str]: + """ + Get all partition names associated with this layer. + Names are returned in order of their partition indices. + + ```python + mapping = {"gate_up_proj", "gate_proj", "up_proj"} + + assert get_layer_partition_names( + "mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"] + assert get_layer_partition_names( + "mlp.down_proj", mapping) == ["down_proj"] + """ + for fused_suffix, part_suffixes in packed_modules_mapping.items(): + if layer_name.endswith(fused_suffix): + return [ + layer_name.removesuffix(fused_suffix) + part_suffix + for part_suffix in part_suffixes + ] + + return [layer_name] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py new file mode 100644 index 0000000000000000000000000000000000000000..b3be25471773418bb1c643c1ac9f45acceefd152 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Hashable +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import TransformLocation, TransformScheme +from torch import Tensor + +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parameter import SharedWeightParameter + + +class HadamardTransform(torch.nn.Module): + """ + Class which handles weight loading, postprocessing, and application of + transforms. Meant to be used with `CompressedTensorsLinearTransformMethod` + and attention transforms method (not implemented yet) + """ + transforms: dict[int, TransformTuple] # info parsed from transforms config + weight: SharedWeightParameter # container for shared tensors + + kernel: Callable # function used during application + scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) + + def __init__(self, + transforms: dict[int, TransformTuple], + layer: torch.nn.Module, + weight_loader: Callable, + input_size_per_partition: int, + output_partition_sizes: list[int], + kernel: Optional[Callable] = None): + super().__init__() + self.transforms = transforms + self.scales = {} + + if get_tensor_model_parallel_world_size() > 1: + raise NotImplementedError("Online transforms with tensor " + "parallelism is not supported") + + # Similar to row/col parallel params, but tensors are separate + # to allow for loading with shared memory + self.weight = SharedWeightParameter(weight_loader=weight_loader) + + # create shared partition data for each partition of the original weight + input_size = input_size_per_partition + for part_index, (_scheme_name, scheme, + args) in self.transforms.items(): + output_size = output_partition_sizes[part_index] + weight_size = self._get_weight_size(layer, args.location, + input_size, output_size) + + data_key = self._get_data_key(scheme, weight_size) + self.weight.add_partition( + part_index, + data_key, + size=(weight_size, weight_size), + dtype=scheme.precision, + ) + + # validate that shared tensors and schemes are correct + self._validate_input_transforms() + + # select kernel based on transform schemes + self.kernel = self._infer_kernel() if kernel is None else kernel + + def process_weights_after_loading(self): + for part_id in self.weight.partitions: + data = self.weight.partitions[part_id].data + + # required by torch.compile + self.weight.process_weights_after_loading() + + # precompute scale as a runtime multiply, not division + # do not fold into weight in order to utilize FWHT + self.scales[part_id] = 1 / math.sqrt(data.size(0)) + + # FUTURE: avoid runtime tranpose by processing weights + # prior to apply + + def forward(self, value: Tensor, part_id: int = 0) -> Tensor: + if part_id not in self.weight.partitions: + return value + + weight = self.weight.partitions[part_id] + weight = weight if self.transforms[ + part_id].args.inverse else weight.T # linear := x(W.T) + scale = self.scales[part_id] + return self.kernel(self, value.to(weight.dtype), weight, None).to( + value.dtype) * scale + + def _get_data_key(self, scheme: TransformScheme, + weight_size: int) -> Hashable: + return (id(scheme), weight_size) + + def _get_weight_size(self, layer: torch.nn.Module, + location: TransformLocation, input_size: int, + output_size: int) -> int: + if isinstance(layer, LinearBase): + if location == TransformLocation.INPUT: + return input_size + + elif location == TransformLocation.OUTPUT: + return output_size + + elif isinstance(layer, VocabParallelEmbedding): + if location == TransformLocation.INPUT: + return output_size + + elif location == TransformLocation.OUTPUT: + return input_size + + raise ValueError() + + def _validate_input_transforms(self): + assert len(self.transforms) > 0 + location = list(self.transforms.values())[0].args.location + + if location == TransformLocation.INPUT: + first_data = self.weight.partitions[0].data + for partition in self.weight.partitions.values(): + if partition.data.data_ptr() != first_data.data_ptr(): + raise ValueError("") + + def _infer_kernel(self) -> Callable: + # TODO (@ksayers): use fwht, hadacore + return dispatch_unquantized_gemm() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py new file mode 100644 index 0000000000000000000000000000000000000000..f42258f9f9d7f22321c54c512dbfb6e080901d0b --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod) + + +# Because qutlass fuses hadamard with quantization, it cannot automatically be +# composed with kernels in the way CompressedTensorsLinearTransformMethod does. +# Therefore, a separate scheme must be created for each quantized dtype +class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod): + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # fused hadamard quant linear method + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f353de1e6a741305ae5035e27033a7e1dfaba42 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple + +from compressed_tensors.transform import TransformArgs, TransformScheme + +__all__ = ["TransformTuple"] + + +class TransformTuple(NamedTuple): + scheme_name: str + scheme: TransformScheme + args: TransformArgs diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 3e43caa4cbf72bf965093f58154a3b8a0e5fbb91..2d8a684bc7d9006eb1fbda374960ceba306784e4 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -120,6 +120,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -146,6 +147,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f07be08554921abe201cdc49faa2056bd1ec6295..48bac8697e4662249421831d4c78420bb478f409 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -9,6 +9,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -44,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -145,7 +148,7 @@ class Fp8Config(QuantizationConfig): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self, layer.moe_config) + return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -219,8 +222,7 @@ class Fp8LinearMethod(LinearMethodBase): self.fp8_linear = Fp8LinearOp( act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, - cutlass_fp8_supported=cutlass_fp8_supported()) + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -372,6 +374,8 @@ class Fp8LinearMethod(LinearMethodBase): # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) + # layer.input_scale is None indicates dynamic quant and scale is + # computed from input. layer.input_scale = None # If checkpoint is fp8, handle that there are N scales for N @@ -422,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase): # On B200, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig): - super().__init__(moe) + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) @@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale) - elif self.flashinfer_moe_enabled: + elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) @@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data - if not self.block_quant: - register_moe_scaling_factors(layer) - rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -714,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used(): + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): # Lazy import to avoid CUDA initialization problems. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ @@ -834,13 +853,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + if self.flashinfer_moe_backend is not None: + # NOTE: weights have to be swapped since the activation is + # applied on different half for flashinfer vs vllm + assert not self.block_quant + register_moe_scaling_factors(layer) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + if self.flashinfer_moe_backend == \ + FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + layer.w13_weight.data = w13_weight.data + if self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) @@ -867,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) @@ -892,6 +923,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): per_act_token_quant=False, allow_deep_gemm=self.allow_deep_gemm, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", @@ -917,6 +955,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -930,67 +969,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if not self.flashinfer_moe_enabled: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_fused_experts) - return rocm_aiter_fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - use_fp8_w8a8=True, - apply_router_weight_on_input=apply_router_weight_on_input, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - expert_map=expert_map) - elif self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") - return torch.ops.vllm.fused_marlin_moe( - x, - layer.w13_weight, - layer.w2_weight, - None, - None, - layer.w13_weight_scale, - layer.w2_weight_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=scalar_types.float8_e4m3fn.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map) - elif self.flashinfer_moe_enabled: - assert activation == 'silu' - assert scoring_func == 'sigmoid' + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") if self.block_quant: assert (renormalize and use_grouped_topk and custom_routing_function is None) @@ -1011,7 +995,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, + routed_scaling=routed_scaling_factor, ) else: assert (not renormalize @@ -1026,28 +1010,99 @@ class Fp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, topk_group=topk_group, apply_router_weight_on_input=apply_router_weight_on_input) - elif self.fused_experts is not None: - return self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_fused_experts) + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, activation=activation, - global_num_experts=global_num_experts, + use_fp8_w8a8=True, apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, w1_scale=(layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - ) + block_shape=self.quant_config.weight_block_size, + expert_map=expert_map) + elif self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + None, + None, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert self.block_quant is None + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) else: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( + common_kwargs = dict( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1064,11 +1119,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant else layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm)) + ) + + if self.fused_experts is not None: + return self.fused_experts(**common_kwargs) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + **common_kwargs, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm), + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 49d28927d6e7434e3e84ad39892eaa36e109f14b..ad648df23819444537a7ca4f224798951b4d83bf 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -13,7 +13,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -28,8 +29,10 @@ logger = init_logger(__name__) class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, ) -> None: + def __init__(self, + unquantized_modules: Optional[list[str]] = None) -> None: super().__init__() + self.unquantized_modules = unquantized_modules or [] def __repr__(self) -> str: return ("GGUFConfig()") @@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig): def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): + if is_layer_skipped_gguf(prefix, self.unquantized_modules): + return UnquantizedLinearMethod() return GGUFLinearMethod(self) elif isinstance(layer, VocabParallelEmbedding): return GGUFEmbeddingMethod(self) @@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig): return None +def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]): + return any(module_name in prefix for module_name in unquantized_modules) + + UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} STANDARD_QUANT_TYPES = { WeightType.Q4_0, @@ -523,6 +532,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -553,6 +563,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f18c936bac60535f4e7a2ff7a90b0086bd0941ef..2272709f93091aff20eeeaa35eb74231e358d97a 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig): desc_act: bool, lm_head_quantized: bool, dynamic: dict[str, dict[str, Union[int, bool]]], + autoround_version: str = "", ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. @@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig): "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits.") + # used to identify GPTQ model quantized by autoround + self.autoround_version = autoround_version + def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " @@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig): desc_act = cls.get_from_keys(config, ["desc_act"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or(config, ["autoround_version"], + default="") return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic) + dynamic, autoround_version) def get_quant_method( self, layer: torch.nn.Module, prefix: str diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c5d1e017014f3f51bd1aff63dea84a4d8d57029e..3644d91f64e3c26ba9221da81387480901bcc87f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig): self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # used to identify GPTQ model quantized by autoround + self.autoround_version = full_config.get("autoround_version", "") + def __repr__(self) -> str: return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " @@ -643,6 +646,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -669,6 +673,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 07ecc096231a427942899117c428c928301170c4..1280f5f1eadf7b894506d5bb4ad69b21cc688544 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -20,6 +20,7 @@ class MPLinearLayerConfig: group_size: int zero_points: bool has_g_idx: bool + out_type: Optional[torch.dtype] = None class MPLinearKernel(ABC): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee92cddbe0b47e6a20a9218c268f00832..4bcfcd04b3d8b0021c85579bbbfa1765fb9f6ed8 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 ConchLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 Dynamic4bitLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 @@ -24,6 +26,7 @@ from vllm.platforms import current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..9e23c0dd3595b037ea1e74c1a26458d8302d4efe --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, + group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 "\ + "(Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "CUTLASS W4A8, only supported int4" + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return False, "K and N must be divisible by 128, got "\ + f"{c.partition_weight_shape}" + + if c.out_type != torch.bfloat16: + return False, "Only bfloat16 output type currently supported"\ + f"got {c.out_type=}" + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b( + x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + self._transform_param(layer, "weight_chan_scale", lambda x: x) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + w_ch_s = layer.weight_chan_scale + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm(a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=w_ch_s) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 18f5ce04fd355e43a7b1727c94d38ab66f11dcb0..2bc68ab3ebd186771710a9a808f166769473244a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -6,6 +6,8 @@ from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( + CPUScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 @@ -18,7 +20,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..59d2b5bce962ecb018313270bd8a72fd5390e5fe --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.layers.utils import check_cpu_sgl_kernel +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CPUScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cpu(): + return False, "CPUScaledMM requires running on CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = getattr(layer, self.w_q_name) + dtype = weight.dtype + N, K = weight.size() + if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype)): + self.linear_method = self._apply_weights_sgl + self.process_weights_for_sgl(layer) + else: + self.linear_method = self._apply_weights_onednn + self.process_weights_for_onednn(layer) + + def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Transpose to [K, N] for convenience + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # oneDNN kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + azp = (int8_traits.min - + range_min / scale).round().to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # Different from cutlass, oneDNN kernels only need the AZP adjustment + # term for dynamic quantization. And s_b should be folded into the + # term. Such as: + # s_a * s_b * [(A - zp_a)B] + bias = + # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = + # s_a * GEMM_output - s_a * zp_a * adj + bias + if not (self.config.input_symmetric + and self.config.is_static_input_scheme): + weight = getattr(layer, self.w_q_name) + weight_scale = getattr(layer, self.w_s_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) + azp_adj = azp_adj * weight_scale.squeeze() + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + weight = getattr(layer, self.w_q_name) + self.dnnl_handler = ops.create_onednn_scaled_mm( + weight, + getattr(layer, self.w_s_name), + torch.get_default_dtype(), + getattr(layer, self.i_s_name) is None, + not self.config.input_symmetric, + 32, + ) + # weight is prepacked and maintained by the dnnl_handler, + # release the original weight + setattr(layer, self.w_q_name, None) + del weight + + def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + # WEIGHT + weight = getattr(layer, self.w_q_name) + packed_weight = torch.ops._C.convert_weight_packed(weight) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(packed_weight, requires_grad=False)) + + if layer.bias is not None: + bias = layer.bias + layer.register_parameter( + "bias_fp32", + torch.nn.Parameter(bias.float().data, requires_grad=False)) + + # WEIGHT SCALE + # CPU SGL kernels only support per-channel. + # For per-tensor quant, convert to the per-channel case. + weight_scale = getattr(layer, self.w_s_name) + if not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.linear_method( + layer, + x, + bias, + ) + + def _apply_weights_onednn( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( + x, i_s, i_zp, self.config.input_symmetric) + + m = x.size(0) + n = self.dnnl_handler.n + out = torch.empty((m, n), dtype=x.dtype) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, + bias) + + return out + + def _apply_weights_sgl( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + return torch.ops._C.int8_scaled_mm_with_quant( + x, + w_q, + w_s, + layer.bias_fp32 if bias is not None else None, + x.dtype, + True, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index b865c5fd7b61724f926009e99c8dd942aca037af..757a4f3772fb1a4f543db759f53f029b13d57923 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -27,8 +27,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): def can_implement( cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - if (not current_platform.is_cuda() and not current_platform.is_cpu()): - return False, "CutlassScaledMM requires running on CUDA or CPU." + if not current_platform.is_cuda(): + return False, "CutlassScaledMM requires running on CUDA." return True, None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py deleted file mode 100644 index 18d1c13373df90fd681d242fb75cf77ef783ca67..0000000000000000000000000000000000000000 --- a/vllm/model_executor/layers/quantization/marlin.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - - -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. - - Reference: https://github.com/IST-DASLab/marlin/tree/master - """ - - def __init__( - self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - super().__init__() - - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: - raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // 4 - - # Tile size used by marlin kernels. - self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = 64 - - # Min in_features dim - self.min_k_threads = 128 - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 - - # Permutation length used by the marlin kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return (f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(group_size, lm_head_quantized) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" - or hf_quant_cfg.get("is_marlin_format", False)) - - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "marlin") - - if is_marlin_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". - format(cls.get_name(), cls.get_name())) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["MarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): - return MarlinLinearMethod(self) - return None - - -class MarlinLinearMethod(LinearMethodBase): - """Linear method for Marlin. - - Args: - quant_config: The Marlin quantization config. - """ - - def __init__(self, quant_config: MarlinConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del output_size # Unused. - weight_loader = extra_weight_attrs["weight_loader"] - - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) - - weight_scale_args = { - "data": - torch.empty( - input_groups, - output_size_per_partition, - device="cuda", - dtype=params_dtype, - ), - "weight_loader": - weight_loader - } - if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) - else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s", scales) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s = Parameter(layer.s.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - scales = layer.s - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, - size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e0f462b36976fb6bf81a14e8f852acf09242f50a..4bb8438d90844fe56db1b7f46077e7440498934e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum from typing import Any, Callable, Optional, Union import torch @@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] -class FlashinferMoeBackend(Enum): - TENSORRT_LLM = "TensorRT-LLM" - CUTLASS = "CUTLASS" - - class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" @@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptFp8MoEMethod(self, layer.moe_config) + return ModelOptFp8MoEMethod(self, layer) return None @@ -278,18 +275,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptFp8Config, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> None: - super().__init__(moe) + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) + + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.fused_experts is not None or \ + self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def create_weights( self, @@ -433,11 +462,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), requires_grad=False) - if self.flashinfer_moe_enabled: + if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) register_moe_scaling_factors(layer) + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) def apply( self, @@ -453,6 +483,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -461,14 +492,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") - if self.flashinfer_moe_enabled: - assert activation == 'silu' + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -492,9 +522,40 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not renormalize + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) return fused_experts( @@ -826,13 +887,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, requires_grad=False) + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter( + (1 / layer.input_scale).to(torch.float32), requires_grad=False) + # Swizzle the weight blockscale. # contracting dimension is input dimension # block_size = 16; assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Block scale must be represented as FP8-E4M3") - if self.backend == "flashinfer-trtllm": + if self.backend == "marlin": + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + elif self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -849,21 +918,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): torch.uint8), epilogue_tile_m).reshape( weight_scale.shape).view(torch.float8_e4m3fn)) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) - if self.backend == "marlin": - prepare_fp4_layer_for_marlin(layer) - del layer.alpha - del layer.input_scale - del layer.weight_scale_swizzled - def apply( self, layer: torch.nn.Module, @@ -885,22 +947,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - s_quant = 1 / layer.input_scale - x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale assert (x_fp4.dtype == torch.uint8) assert (layer.weight.dtype == torch.uint8) assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) + assert (layer.weight_scale.dtype == torch.float8_e4m3fn) assert (layer.alpha.dtype == torch.float32) mm_args = ( x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, + layer.weight_scale, layer.alpha, output_dtype, ) @@ -951,42 +1012,32 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.flashinfer_moe_backend = None if self.allow_flashinfer: - flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_moe_backend == "throughput": - self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - logger.info_once("Using FlashInfer CUTLASS kernels for " - "ModelOptNvFp4FusedMoE.") - elif flashinfer_moe_backend == "latency": - self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM - logger.info_once("Using FlashInfer TensorRT-LLM kernels for " - "ModelOptNvFp4FusedMoE.") - else: - allowed_backends = ["throughput", "latency"] - raise ValueError( - f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}") - - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore[assignment] + self.flashinfer_moe_backend = get_flashinfer_moe_backend() + logger.info_once( + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + " for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( self, moe: FusedMoEConfig, ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if not self.allow_flashinfer: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize + if (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.CUTLASS): + prepare_finalize = ( + build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + )) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + return super().maybe_make_prepare_finalize(moe) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: experts = select_nvfp4_gemm_impl( moe, @@ -1265,6 +1316,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): del layer.w2_weight_scale del layer.w13_weight del layer.w13_weight_scale + elif self.use_marlin: + # Marlin processing + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant else: # Non-TRT-LLM processing (Cutlass or non-flashinfer) assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( @@ -1273,28 +1331,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): "Weight Blockscale must be represented as FP8-E4M3") w13_blockscale_swizzled = swizzle_blockscale( layer.w13_weight_scale) - layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, - requires_grad=False) + layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, + requires_grad=False) assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( "Expected weight_scale.dim(1) to be divisible by 16") assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Blockscale must be represented as FP8-E4M3") w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, - requires_grad=False) + layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, + requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) - if self.use_marlin: - prepare_moe_fp4_layer_for_marlin(layer) - del layer.g1_alphas - del layer.g2_alphas - del layer.w13_input_scale_quant - del layer.w2_input_scale_quant - del layer.w13_blockscale_swizzled - del layer.w2_blockscale_swizzled - def apply( self, layer: torch.nn.Module, @@ -1309,6 +1358,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1387,6 +1437,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) @@ -1409,7 +1460,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_num_experts=global_num_experts, expert_map=expert_map) - if self.fused_experts is None: + if self.fused_experts is not None: + assert self.allow_flashinfer and \ + self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif (self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + out = flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import ( @@ -1418,8 +1514,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -1432,27 +1528,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e=layer.w13_weight.shape[0], expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - else: - assert self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") - - out = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - apply_router_weight_on_input=apply_router_weight_on_input, - ) return out diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 9cc98662c1763b26420bc22d1059dd7e16f1ccf8..2311954f51082a169770060e2394e602faffbd17 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig): awq_min_capability = AWQConfig.get_min_capability() gptq_compatible = quant_method == "gptq" and \ - not desc_act and num_bits in [4, 8] + not desc_act and num_bits in [4, 8] awq_compatible = quant_method == "awq" and num_bits == 4 and \ device_capability >= awq_min_capability @@ -182,11 +182,8 @@ class MoeWNA16Method(FusedMoEMethodBase): quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__( - self, - quant_config: MoeWNA16Config, - moe: FusedMoEConfig, - ): + def __init__(self, quant_config: MoeWNA16Config, + moe: "FusedMoEConfig") -> None: super().__init__(moe) self.quant_config = quant_config self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' @@ -199,6 +196,7 @@ class MoeWNA16Method(FusedMoEMethodBase): hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size @@ -341,6 +339,7 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -353,7 +352,6 @@ class MoeWNA16Method(FusedMoEMethodBase): use_fused_gate: Optional[bool] = False, ) -> torch.Tensor: assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `MoeWNA16Method` yet.") @@ -369,6 +367,7 @@ class MoeWNA16Method(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, routed_scaling_factor=routed_scaling_factor, @@ -490,13 +489,14 @@ class MoeWNA16Method(FusedMoEMethodBase): def moe_wna16_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, shard_id: str, + weight_name: str, + shard_id: str, expert_id: int, return_success: bool = False): if "g_idx" in weight_name: - return + return False if return_success else None if not layer.quant_config.has_zp and "qzeros" in weight_name: - return + return False if return_success else None device = get_tp_group().device tp_rank = get_tensor_model_parallel_rank() @@ -542,12 +542,18 @@ class MoeWNA16Method(FusedMoEMethodBase): param.data[expert_id, :shard_size // 2] = tensor else: param.data[expert_id, shard_size // 2:] = tensor + return True if return_success else None elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + return True if return_success else None else: - weight_loader(param, loaded_weight, weight_name, shard_id, - expert_id) - return_success = True - return return_success - return moe_wna16_weight_loader \ No newline at end of file + # Delegate to the original loader, passing return_success + return weight_loader(param, + loaded_weight, + weight_name, + shard_id, + expert_id, + return_success=return_success) + + return moe_wna16_weight_loader diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6a190ebbc063e5f37a0e34a5c5997c7e53f3212e..91db47be6c1ca7f46ea27b65aff482ac06d589b0 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,11 +6,13 @@ import torch from torch.nn.parameter import Parameter from vllm import envs + +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) -from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - triton_kernel_moe_forward) +from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -113,6 +115,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.topk_indices_dtype = None self.moe = moe self.use_marlin = self._should_use_marlin() + self.max_capture_size = get_current_vllm_config( + ).compilation_config.max_capture_size + + if current_platform.is_device_capability(100) and not has_flashinfer(): + logger.warning_once( + "MXFP4 MoE is enabled on Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results.") if current_platform.is_device_capability(100) and not has_flashinfer(): logger.warning_once( @@ -444,6 +454,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return tile_tokens_dim + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + if (prepare_finalize.activation_format == + mk.FusedMoEActivationFormat.BatchedExperts): + raise NotImplementedError( + "Mxfp4 does not support batched experts format for EP") + else: + if should_use_flashinfer_mxfp4(): + # B200 code-path + kwargs = { + "gemm1_alpha": layer.gemm1_alpha, + "gemm1_beta": layer.gemm1_beta, + "gemm1_clamp_limit": layer.gemm1_clamp_limit, + "w13_bias": layer.w13_bias, + "w2_bias": layer.w2_bias, + "max_capture_size": self.max_capture_size, + } + return TrtLlmGenExperts(moe, **kwargs) + else: + # Use matmul_ogs from triton_kernels here! + raise NotImplementedError( + "Mxfp4 does not support non-batched experts format for EP") + + def _route_and_experts( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None + ) -> torch.Tensor: + + assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + def apply( self, layer: torch.nn.Module, @@ -458,6 +553,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -481,6 +577,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) return torch.ops.vllm.fused_marlin_moe( @@ -502,6 +599,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): activation=activation, expert_map=expert_map) + if self.fused_experts is not None: + return self._route_and_experts( + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) + assert _can_support_mxfp4( use_grouped_topk, topk_group, num_expert_group, expert_map, custom_routing_function, e_score_correction_bias, @@ -512,15 +632,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if should_use_flashinfer_mxfp4(): from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe - assert not self.moe.use_ep, ( - "EP is not supported for flashinfer mxfp4 moe backend yet.") if _should_use_flashinfer_mxfp4_bf16(): assert x.dtype == torch.bfloat16 x_quant = x x_scale = None else: x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x.shape[:-1], -1) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -538,20 +657,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): None, # output1_scale_scalar None, # output1_scale_gate_scalar None, # output2_scale_scalar - self.num_experts, + global_num_experts, top_k, None, # n_group None, # topk_group self.intermediate_size, # padded to multiple of 256 - 0, # local_expert_offset + layer.ep_rank * layer.local_num_experts, # local_expert_offset self.num_experts, # local num experts None, self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize + tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output else: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward) return triton_kernel_moe_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9fee69bb02119940b19e7b99cb3d051bdcf0d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/petit.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +from typing import Any, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.petit_utils import ( + apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +# Initialize logger for the module +logger = init_logger(__name__) + + +# Configuration class to support the NVFP4 quantized model +# generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: Optional[str] = None, + group_size: Optional[int] = None, + exclude_modules: Optional[list[str]] = None, + ) -> None: + self._check_hardware_support() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning("Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change.") + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + def _check_hardware_support(self) -> None: + """ + Verifies that the current hardware is supported by the Petit backend. + This backend is specifically designed for AMD GPUs and is not + supported on the CUDA platform. + """ + # This check ensures the code is NOT running on an NVIDIA GPU. + if current_platform.is_cuda(): + raise ValueError( + "The 'petit' quantization backend is designed for AMD GPUs " + "and is not supported on the CUDA platform. For NVIDIA GPUs, " + "please use a different quantization method such as FP8, AWQ, " + "or GPTQ.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": + qc = cls.get_from_keys(config, ["quantization"]) + + quant_method_raw = qc.get("quant_algo") + if not isinstance(quant_method_raw, str) or not quant_method_raw: + raise ValueError( + "Missing or invalid 'quant_algo' in quantization config.") + quant_method = quant_method_raw.upper() + + group_size_raw = qc.get("group_size") + if not isinstance(group_size_raw, int): + raise ValueError( + "Missing or invalid 'group_size' (int) in hf_quant_config.json." + ) + group_size = group_size_raw + + verify_petit_nvfp4_supported(quant_method, group_size) + + kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" + if not isinstance(kv_cache_quant_algo_raw, str): + raise ValueError( + "'kv_cache_quant_algo' must be a string if provided.") + kv_cache_quant_algo = kv_cache_quant_algo_raw + + exclude_raw = qc.get("exclude_modules", []) + if exclude_raw is None: + exclude_modules: list[str] = [] + elif isinstance(exclude_raw, list) and all( + isinstance(x, str) for x in exclude_raw): + exclude_modules = exclude_raw + else: + raise ValueError( + "'exclude_modules' must be a list[str] (or omitted).") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + group_size=group_size, + exclude_modules=exclude_modules, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_rocm(): + return None + + qc = hf_quant_cfg.get("quantization", hf_quant_cfg) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"): + return cls.get_name() # "petit_nvfp4" + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: + qc = quant_config.get("quantization", quant_config) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + return algo == "NVFP4" + + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + exclude = self.require_exclude_modules() + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( + prefix, exclude): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return PetitFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> list[str]: + return [] + + def require_group_size(self) -> int: + if self.group_size is None: + logger.warning("group_size not set; defaulting to 16 for NVFP4.") + return 16 + return self.group_size + + def require_kv_cache_quant_algo(self) -> str: + return self.kv_cache_quant_algo or "auto" + + def require_exclude_modules(self) -> list[str]: + return list(self.exclude_modules or []) + + +class PetitFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + super().__init__(quant_config) + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + group_size = self.quant_config.require_group_size() + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index d11cba2caba880f44574a923579503ac3b4b3510..466fd5fba768542806c161f278a5522f7158e9a4 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( act_quant_static=False, - cutlass_fp8_supported=False, - act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_group_shape=GroupShape.PER_TOKEN, + force_fp8_e4m3fnuz=True) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py deleted file mode 100644 index 25978cb13b3ab6b6aedc2b5f979b4763d35e75dd..0000000000000000000000000000000000000000 --- a/vllm/model_executor/layers/quantization/qqq.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - - -class QQQConfig(QuantizationConfig): - """Config class for QQQ - - Reference: https://arxiv.org/pdf/2406.09904 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - is_sym: bool = True, - ) -> None: - super().__init__() - self.weight_bits = weight_bits - self.group_size = group_size - self.is_sym = is_sym - - # Verify - if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: - raise ValueError( - f"QQQ does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"QQQ does not support group_size = {self.group_size}. " - f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: - raise ValueError( - f"QQQ does not support is_sym = {self.is_sym}. " - f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits - - # Tile size used by QQQ kernels. - self.tile_size = MARLIN_QQQ_TILE - - # Min out_features dim - self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N - - # Min in_features dim - self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = MARLIN_QQQ_MAX_PARALLEL - - # Permutation length used by the QQQ kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return "QQQConfig(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "qqq" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "QQQConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - return cls(weight_bits, group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QQQLinearMethod"]: - if isinstance(layer, LinearBase): - return QQQLinearMethod(self) - return None - - -class QQQLinearMethod(LinearMethodBase): - """Linear method for QQQ. - - Args: - quant_config: The QQQ quantization config. - """ - - def __init__(self, quant_config: QQQConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs["weight_loader"] - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - s_channel = ChannelQuantScaleParameter(data=torch.empty( - 1, - output_size_per_partition, - device="cuda", - dtype=torch.float, - ), - weight_loader=weight_loader, - output_dim=1) - - if self.quant_config.group_size == -1: - s_group_data = torch.tensor( - [], - device="cuda", - dtype=torch.half, - ) - else: - s_group_data = torch.empty( - input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, - device="cuda", - dtype=torch.half, - ) - - s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} - - if self.quant_config.group_size == -1: - s_group = BasevLLMParameter(**s_group_attr) - else: - s_group = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **s_group_attr) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s_channel", s_channel) - layer.register_parameter("s_group", s_group) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) - layer.s_group = Parameter(layer.s_group.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - s_ch = layer.s_channel - s_group = layer.s_group - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = s_ch.shape[1] - - x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) - - output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 58f56c6381b317a949839079b8308241e82aded3..fdf03ded044807c22d5f1417fe875a3f5d40de8b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -218,6 +218,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -244,6 +245,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) @@ -380,6 +382,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -406,6 +409,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 8bdb50e07b1373308a99ff7afbb24af5c4046c55..8f72b8cbea7a75c133fcb7fe36d2f3a5fcdc6a1e 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -283,6 +283,7 @@ class RTNMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -309,6 +310,7 @@ class RTNMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..0ea0225c96af134518767e7b0554eb0402e0f454 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..be487f2805b8597636e5551ce217dbf4ba08b0bd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..f81e09e198c869e5817f1bf64166b171dffdaacb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..e073843af64c5658f47ba6afc0ea5e21b10ff301 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..f74a52fc17c9d085d11be052118f319a405e9007 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..8cab1b093276a6f3e9023fc2644401838f7baf71 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 1c61451fb34e52deec827f8f63c80fb15830c202..ae244f90bb06451aeb87e878cb9d2adc6abc3dfc 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,73 +1,73 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 4 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -115,15 +115,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,13 +133,13 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63e661c80de6a7b1422f7a994a2ee7a4b724911c..b2931d68f488aa758b3f40ba010b2e5032d0a476 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,83 +1,83 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 4 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -99,9 +99,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } -} +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 56b939e52fac3ed53a4e0ba640c40010cb3af30a..ad630f0d787cf7fe86ab5289514ae3d5c2bd8aa3 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,30 +1,30 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -32,19 +32,19 @@ "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, - "num_warps": 4, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,9 +59,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63d9a0bf5d79ddaaad547d44338ad4b959ad72b1..10b940c04fad358db81e0360184442697e198746 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,50 +1,50 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,15 +59,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 7fa398c15a2a535401709b0f25e20f6e4b23e58e..94ce6e77f09cec01467c6d4a42c94d2a0815d142 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,55 +1,55 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -59,31 +59,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,7 +99,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index f15d8f64c7090bd71d0091a524c65d7818fec38e..9540df407975e1ec33afa925967004fce5586728 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -59,33 +59,33 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,23 +93,23 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..96f6c307b357dcae4a6be50a82ed835d31326d47 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000000000000000000000000000000..567675787d4f9dda9487213414cb78547fb35369 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 51e237b91b8e775a36bcf783c078c2c1cecbcbd2..0894ff2fa33224770ddf37ca123c4008cc9a742c 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,6 +1,6 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -8,55 +8,55 @@ "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "48": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -64,74 +64,74 @@ "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "2048": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 6280219c9ee7d26f7e2fd3625dc92d847ddc7982..86c68e08a1a6a6f2cabf994df45a71a0a8d4a69c 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,78 +1,78 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "96": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, @@ -80,58 +80,58 @@ "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 0a1e14cffbb2a894a701352193947d272427db0d..af1a384cbcbd3b7c674e98b40a62b9230944aa30 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,14 +1,14 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -16,26 +16,26 @@ "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, @@ -43,9 +43,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -67,31 +67,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -101,31 +101,31 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 15b1c93f60fc5068ba11b82b6d5924dd2024a824..d381764a2641445760db7e0372c28c5477d8c53b 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,22 +1,22 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, - "num_warps": 4, + "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -24,18 +24,18 @@ "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, @@ -45,47 +45,47 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,29 +93,29 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 8ff12e64c172f5a5d0fbdf900728fe60b33877e2..821ad0c70457318b2e173c7d5974d8cbeeec0029 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,43 +1,43 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, @@ -45,7 +45,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, @@ -73,19 +73,19 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -99,21 +99,21 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -123,9 +123,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 4532f93681e2be175b1bf94f81bfde711821cd60..daaf21c286553953fe4fbd690a45cecee0339455 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,67 +1,67 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, @@ -73,25 +73,25 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,31 +99,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index ca7f32b9552b479dc05495792b7e426db5eb1b56..2583b5a3441caf622e1e8be428ad28f4f5ee786b 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 5 + "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -59,39 +59,39 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,23 +99,23 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 5acea242cc0ad094cba8ee5f568ff88afb1b41ae..baa64f8d3d141c3ce9b60dea9be8a09c9060acd2 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,65 +1,65 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -69,21 +69,21 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,13 +99,13 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -131,15 +131,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1110ced4fa063fd00c92aadcefdaa736e2762b7f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel Config + +Use scripts under `benchmarks/kernels/` to generate these config files. diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 278ee5232f47e640e5b72aa1110f307410e196cd..9889808f0760f178dba91a03486b040e879f8109 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,9 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Optional import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) + +logger = init_logger(__name__) + + +class FlashinferMoeBackend(Enum): + TENSORRT_LLM = "TensorRT-LLM" + CUTLASS = "CUTLASS" + def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): @@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: layer.register_parameter( 'output2_scales_scalar', torch.nn.Parameter(output2_scales, requires_grad=False)) + layer.register_parameter( + 'w2_input_scale_inv', + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + + +def build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp, a1_gscale=layer.w13_input_scale) + + +def select_cutlass_fp8_gemm_impl( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, + out_dtype: Optional[torch.dtype] = None, +) -> mk.FusedMoEPermuteExpertsUnpermute: + """Return a GEMM *experts* implementation for fused-MoE layers""" + + from vllm.model_executor.models.llama4 import Llama4MoE + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + "FusedMoE flashinfer kernels are only supported for Llama4" + + if moe is not None: + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=moe.in_dtype, + quant_dtype=torch.float8_e4m3fn, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + + assert out_dtype is not None, ( + "If moe config is None, out_dtype must be passed") + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=out_dtype, + quant_dtype=torch.float8_e4m3fn, + ) + + +def flashinfer_cutlass_moe_fp8( + hidden_states: torch.Tensor, + layer: torch.nn.Module, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + fused_experts = mk.FusedMoEModularKernel( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, + layer=layer), + select_cutlass_fp8_gemm_impl(moe=None, + layer=layer, + out_dtype=hidden_states.dtype)) + + return fused_experts( + hidden_states, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def get_flashinfer_moe_backend() -> FlashinferMoeBackend: + flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND + if flashinfer_moe_backend == "throughput": + return FlashinferMoeBackend.CUTLASS + elif flashinfer_moe_backend == "latency": + return FlashinferMoeBackend.TENSORRT_LLM + + allowed_backends = ["throughput", "latency"] + raise ValueError( + f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" + f" expected one of {allowed_backends}") diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 61fd7e3a050119cb537fae929aef573d35e7fdc8..17177f40a8f12b195a6ed0c14e355f33b7245694 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -20,8 +20,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -109,19 +110,6 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul -def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): - """ - Check if DeepGEMM should be used based on the output dtype and weight shape. - DeepGEMM is only supported for bfloat16 output dtype and weights with shape - divisible by 128. - """ - - return (current_platform.is_cuda() - and current_platform.is_device_capability(90) and has_deep_gemm() - and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - - # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -140,7 +128,7 @@ def apply_w8a8_block_fp8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm(output_dtype, weight): + if should_use_deepgemm_for_fp8_linear(output_dtype, weight): input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -151,7 +139,9 @@ def apply_w8a8_block_fp8_linear( column_major_scales=True, ) + # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, @@ -396,7 +386,7 @@ def per_token_group_quant_fp8( scaling factor. """ if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_e8m0_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index db82b0def1653c60d436c530be4b37fab30b5d92..4fbd0f5c4efff76b72968f2b216350f5597e6d10 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy +from fractions import Fraction from typing import Optional, Union import regex as re @@ -29,7 +30,7 @@ def override_config(config: QuantizationConfig, prefix: str): if isinstance(desc_act, bool): config.desc_act = desc_act - config.pack_factor = 32 // config.weight_bits # packed into int32 + config.pack_factor = Fraction(32, config.weight_bits) # packed into int32 if config.get_name() == "gptq_marlin": is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) if isinstance(is_sym, bool): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index 8a64bebae04c9b52430d6c4f76dfe79458989c9a..0000000000000000000000000000000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: list[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: list[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 48f9cc3737e479f1dd2e79fef91250b0e46a1d19..3de928fea7202bdc75b6e35f433fc05f076680b8 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None): return not (use_grouped_topk or topk_group or num_expert_group - or expert_map or custom_routing_function - or e_score_correction_bias or apply_router_weight_on_input - or scoring_func != "softmax" or activation != "swigluoai" - or expert_load_view or logical_to_physical_map - or logical_replica_count) + or custom_routing_function or e_score_correction_bias + or apply_router_weight_on_input or scoring_func != "softmax" + or activation != "swigluoai" or expert_load_view + or logical_to_physical_map or logical_replica_count) def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6b21c918f463c522bfc6e955c4d840a2d14636 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + + try: + from flashinfer import mxfp8_quantize + except ImportError as err: + raise ImportError("The package `flashinfer` is required to do " + "MX-FP8 quantization. Please install it with" \ + "`pip install flashinfer`") from err + + return mxfp8_quantize(x, is_sf_swizzled_layout=False) diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00d3def1db81e40da3d36333c3f10a8891cc8a8c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +# TYPE_CHECKING is used for static type analysis to prevent circular imports. +if TYPE_CHECKING: + from types import ModuleType + +# 1. Create a global variable as a placeholder for the module +_petit_kernel: Optional["ModuleType"] = None + +_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " + "`pip install petit-kernel`.") + + +def _import_petit_kernel() -> "ModuleType": + """ + A helper function to handle the lazy import. + The first time this function is called, it will import the petit_kernel + library and store it in the global _petit_kernel variable. + Subsequent calls will return the already-loaded module directly. + """ + global _petit_kernel + if _petit_kernel is not None: + return _petit_kernel + + try: + import petit_kernel + _petit_kernel = petit_kernel + return _petit_kernel + except ImportError: + # The 'from None' syntax prevents chaining the original ImportError, + # making the traceback cleaner. + raise ImportError(_PETIT_INSTALL_MSG) from None + + +# The _require_petit function can now be a simple alias for consistency. +_require_petit = _import_petit_kernel + + +def _check_petit_nvfp4_supported( + quant_method: str, + group_size: Optional[int]) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + ("Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration."), + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16 quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, + group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported( + quant_method, group_size) + if not supported: + assert error_msg is not None + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # 2. Call _import_petit_kernel() to trigger (or get) the import. + petit_kernel = _import_petit_kernel() + + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + + # 3. Call functions through the imported module variable. + petit_qweight = petit_kernel.repack_nvfp4(qweight, + size_n=part_size_n, + size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, + size_k=part_size_k, + size_n=part_size_n) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # Trigger (or get) the import here as well. + petit_kernel = _import_petit_kernel() + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + # TODO: Use auto-tuning to find the performant solution_id + # Call the function via the module variable. + output = petit_kernel.mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 428e9e99aa8815d3fa6103155402d22b6f07d94a..6154fca2e416df7e9811efd95c267ee01c2b3a9a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,18 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple, Optional import numpy import torch +from torch import fx from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): @@ -36,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +@dataclass(frozen=True) +class ScaleDesc: + """ + Class for describing a single quantization scaling factor. + dtype: data type of the scale + static: static scale if True, dynamic if False + group_shape: group shape of the scale + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}") + + +@dataclass(frozen=True) +class QuantKey: + """ + Class for identifying the type of quantization. + dtype: quantized data type + scale: scale descriptor + scale2: second-level scale descriptor + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + scale: ScaleDesc + scale2: Optional[ScaleDesc] = None + symmetric: bool = True + + def __str__(self): + scale2_str = f"scale2({self.scale2})," if self.scale2 else "" + return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)") + + +kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) + +kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) + +kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) + +kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Quant = QuantKey(FP4_DTYPE, + scale=kNvfp4GroupScale, + scale2=kStaticTensorScale) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent @@ -386,89 +447,6 @@ def gptq_quantize_weights(w: torch.Tensor, return w_ref, w_q, w_s, g_idx, rand_perm -# QQQ employs different quant schemes for per-group and -# per-channel quantization. -def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ - f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / - s_channel).to(dtype=torch.half) - else: - max_q_val = 2**(num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= (2**(8 - num_bits)) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device @@ -637,8 +615,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() if scale_ndim == 2: - return swizzled.reshape(M, K) - return swizzled.reshape(B, M, K) + return swizzled.reshape(M_padded, K_padded) + return swizzled.reshape(B, M_padded, K_padded) def cutlass_fp4_supported() -> bool: diff --git a/vllm/model_executor/layers/quantization/utils/w4a8_utils.py b/vllm/model_executor/layers/quantization/utils/w4a8_utils.py index 9eb107924b4c50fd168d54711770e1ab47faa69a..4163f8e894d50f85b65ca6d30e475ea9c94cfbe6 100644 --- a/vllm/model_executor/layers/quantization/utils/w4a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w4a8_utils.py @@ -54,11 +54,12 @@ def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8): orig_device = q_w.device - q_w = q_w.cpu().numpy().astype(np.uint32) - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + q_w = q_w.contiguous().to(torch.int32) + M, N = q_w.shape + assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})" + q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device) for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << 4 * i - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + q_packed += q_w[:, i::pack_factor] << (4 * i) return q_packed diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8b9a64f6ee8571cda1ebd14cc3543ec48d60377d..544cd8bd5af95deb605dc3db2b95baec19e487ac 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,9 +13,13 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform + from vllm.utils import W8a8GetCacheJSON from lmslim.layers.gemm.int8_utils import per_token_quant_int8 +from vllm.utils import direct_register_custom_op +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer + # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None @@ -158,13 +162,23 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs) -> torch.Tensor: + + return flashinfer_scaled_fp8_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -177,10 +191,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a=scale_a, scale_b=scale_b, bias=bias) + return output + +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), + dtype=out_dtype) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) +direct_register_custom_op( + op_name="rocm_per_tensor_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_w8a8_scaled_mm_impl, + mutates_args=[], + fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, + dispatch_key=current_platform.dispatch_key, +) + + def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, @@ -207,8 +249,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: + input_2d: torch.Tensor, output_shape: list, + **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. @@ -279,16 +321,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - cutlass_fp8_supported: bool, per_tensor_weights: bool, + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool) -> Callable[..., torch.Tensor]: - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if cutlass_fp8_supported: - return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: - if current_platform.is_rocm(): + if preferred_backend == "rocm": return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": + return cutlass_w8a8_scaled_mm + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) if not per_tensor_weights and not per_tensor_activations \ and USE_ROWWISE_TORCH_SCALED_MM: @@ -310,10 +358,20 @@ class Fp8LinearOp: def __init__(self, act_quant_static: bool, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None): - self.cutlass_fp8_supported = cutlass_fp8_supported + pad_output: Optional[bool] = None, + force_fp8_e4m3fnuz: bool = False): + if current_platform.is_rocm(): + self.preferred_backend = "rocm" + elif current_platform.is_cuda( + ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported(): + if has_flashinfer() and current_platform.has_device_capability( + 100): + self.preferred_backend = "flashinfer" + else: + self.preferred_backend = "cutlass" + else: + self.preferred_backend = "torch" # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -323,8 +381,7 @@ class Fp8LinearOp: if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE and \ - not cutlass_fp8_supported and \ - not current_platform.is_rocm() + self.preferred_backend == "torch" self.output_padding = 17 if pad_output else None self.act_quant_static = act_quant_static @@ -369,9 +426,9 @@ class Fp8LinearOp: per_tensor_activations = (x_scale.numel() == 1) # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations) + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend, + per_tensor_weights, + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight, diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..05322e56f2620c4b1a8723af92a1d2c739200bca --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + + +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): + """3D rotary positional embedding. 3D is t:time h:height w:width""" + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 + assert section_h == section_w + # Split according to [h w h w h w h w... t t t...] + section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., :section_h + section_w:2] + section_cos_w = cos[..., 1:section_h + section_w:2] + + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ + 1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], + dim=-1).reshape(cos_h.shape[:-1] + + (cos_h.shape[-1] * 2, )) + cos = torch.cat([cos_hw, cos_t], dim=-1) + + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., :section_h + section_w:2] + section_sin_w = sin[..., 1:section_h + section_w:2] + + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ + 1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], + dim=-1).reshape(sin_h.shape[:-1] + + (sin_h.shape[-1] * 2, )) + sin = torch.cat([sin_hw, sin_t], dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a091cfb743291642caf640a94bfcad687c521fec..5686ec7b35de88aa80722acd7296ae1bf899ae00 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -393,6 +393,15 @@ class MRotaryEmbedding(RotaryEmbedding): context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: + return cls._ernie_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -513,6 +522,120 @@ class MRotaryEmbedding(RotaryEmbedding): len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def _ernie_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_conv_size, w // spatial_conv_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = (t // + temporal_conv_size, + h // + spatial_conv_size, + w // + spatial_conv_size) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + @classmethod def _vl_get_input_positions_tensor( cls, @@ -547,12 +670,18 @@ class MRotaryEmbedding(RotaryEmbedding): image_index, video_index = 0, 0 for _ in range(image_nums + video_nums): video_second_per_grid_t = 0.0 - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 else: ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index d9c40cb2a79acab13d749257c17e2508b03d7fc6..03970ff9739d85077a998ed52ccd74530a621b2e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -142,6 +142,12 @@ direct_register_custom_op( ) +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype): + return (torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 + and n % 16 == 0) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 2b8e4427591c1530cc43b15964a15ea8b7114e7e..34b8d8e4ed6228df467827df3f62c084bc5b8a04 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader): ) if current_platform.is_tpu(): - # In PyTorch XLA, we should call `xm.mark_step` frequently so that - # not too many ops are accumulated in the XLA program. - import torch_xla.core.xla_model as xm + from vllm.platforms.tpu import USE_TPU_COMMONS - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - xm.mark_step() + if not USE_TPU_COMMONS: + # In PyTorch XLA, we should call `xm.mark_step` + # requently so that not too many ops are accumulated + # in the XLA program. import torch_xla.core.xla_model + # as xm + import torch_xla.core.xla_model as xm - weights_iterator = _xla_weights_iterator(weights_iterator) + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 21655b0c69bb4b46a1f3be2ee58e4ae665ec2a90..9877cb3b7c06e3caddaf5c7adb2eab49003f67a8 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, get_gguf_weight_type_map, + gguf_quant_weights_iterator) class GGUFModelLoader(BaseModelLoader): @@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader): local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) + weight_type_map = get_gguf_weight_type_map(model_config.model, + gguf_weights_map) + + # filter out unquantized modules to skip + unquant_names = [ + name.removesuffix(".weight") + for name, weight_type in weight_type_map.items() + if weight_type == "F32" and name.endswith(".weight") + ] + vllm_config.quant_config.unquantized_modules.extend(unquant_names) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index b44c165397d02cc5f3079e1e24742b307228c5ee..a70cdeb483e67d6a91b227e23d1d172b91394183 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -98,14 +98,15 @@ class TPUModelLoader(DefaultModelLoader): # Check parameters for name, param in model.named_parameters(): - assert param.device.type == device_type, f"Parameter {name} is on \ - {param.device.type} instead of {device_type}" + assert param.device.type == device_type, ( + f"Parameter {name} is on {param.device.type} " + f"instead of {device_type}") # Check buffers for name, buffer in model.named_buffers(): - assert buffer.device.type == device_type, \ - f"Buffer {name} is on {buffer.device.type} instead of \ - {device_type}" + assert buffer.device.type == device_type, ( + f"Buffer {name} is on {buffer.device.type} " + f"instead of {device_type}") for module in model.modules(): if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index eb1162cbb59f0817fc26df2450b835b0a126dffe..5f16bab42b6eaba8f6bc8f7f4713e4145f80ca53 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -31,9 +31,7 @@ from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule( "runai_model_streamer") # type: ignore[assignment] SafetensorsStreamer = runai_model_streamer.placeholder_attr( @@ -280,33 +278,48 @@ def download_weights_from_hf( Returns: str: The path to the downloaded model weights. """ + assert len(allow_patterns) > 0 local_only = huggingface_hub.constants.HF_HUB_OFFLINE if not local_only: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] + # Attempt to reduce allow_patterns to a single pattern + # so we only have to call snapshot_download once. + try: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, + detail=False, + revision=revision) + + # Use the first pattern found in the HF repo's files. + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] break + except Exception as e: + logger.warning( + "Failed to get file list for '%s'. Trying each pattern in " + "allow_patterns individually until weights have been " + "downloaded. Error: %s", model_name_or_path, e) logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): start_time = time.perf_counter() - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision, - local_files_only=local_only, - ) + for allow_pattern in allow_patterns: + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_pattern, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=local_only, + ) + # If we have downloaded weights for this allow_pattern, + # we don't need to check the rest. + if any(Path(hf_folder).glob(allow_pattern)): + break time_taken = time.perf_counter() - start_time if time_taken > 0.5: logger.info("Time spent downloading weights for %s: %.6f seconds", @@ -582,6 +595,18 @@ def get_gguf_extra_tensor_names( return [gguf_to_hf_name_map[key] for key in extra_keys] +def get_gguf_weight_type_map( + gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]: + """ + Return GGUF mapped weight's name and its quant type + """ + reader = gguf.GGUFReader(gguf_file) + return { + gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name + for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map + } + + def gguf_quant_weights_iterator( gguf_file: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1dbe70f84a6264e473ad3ab1eb33ff7ceafb760b..50c2cd97f3d0993bd54daecd95ae7635f5f9653b 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig +from vllm.transformers_utils.config import (get_hf_file_bytes, + get_hf_file_to_dict) from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig _T = TypeVar("_T", bound=type[nn.Module]) +logger = init_logger(__name__) + _GENERATE_SUFFIXES = [ "ForCausalLM", "ForConditionalGeneration", @@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [ ] +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: + """Load Sentence-Transformers Dense projection layers.""" + + try: + modules = get_hf_file_to_dict("modules.json", model_config.model, + model_config.revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules + if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + module = dense_modules[0] + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model_config.model, + model_config.revision) + if not layer_config: + return None + + linear = nn.Linear(layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=torch.float32) + + if _load_dense_weights(linear, folder, model_config): + layers = [linear] + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).to(dtype=torch.float32) + + except Exception: + logger.exception("ST projector loading failed") + + return None + + +def _load_dense_weights(linear: nn.Linear, folder: str, + model_config: "ModelConfig") -> bool: + """Load weights using vLLM's weight_loader pattern.""" + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) + + for filename in ["model.safetensors", "pytorch_model.bin"]: + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_bytes(file_path, model_config.model, + model_config.revision) + if not file_bytes: + continue + + if filename.endswith(".safetensors"): + from safetensors.torch import load as load_safetensors + state_dict = load_safetensors(file_bytes) + else: + import io + state_dict = torch.load(io.BytesIO(file_bytes), + map_location="cpu", + weights_only=True) + + for weight_key in ["weight", "linear.weight", "dense.weight"]: + if weight_key in state_dict: + weight_loader = getattr(linear.weight, "weight_loader", + default_weight_loader) + weight_loader(linear.weight, + state_dict[weight_key].to(torch.float32)) + + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = getattr(linear.bias, "weight_loader", + default_weight_loader) + bias_loader(linear.bias, + state_dict[bias_key].to(torch.float32)) + return True + except Exception: + logger.exception("Failed to load %s", filename) + continue + + return False + + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name @@ -152,7 +248,7 @@ def as_seq_cls_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.linear import RowParallelLinear + from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.pooler import (ClassifierPooler, DispatchPooler, Pooler, PoolingMethod, PoolingType) @@ -168,10 +264,9 @@ def as_seq_cls_model(cls: _T) -> _T: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.score = RowParallelLinear( + self.score = ReplicatedLinear( config.hidden_size, config.num_labels, - input_is_parallel=False, bias=False, params_dtype=torch.float32, quant_config=quant_config, diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py new file mode 100644 index 0000000000000000000000000000000000000000..0de683d2cd060b94d34b4d70a3f49b811c89299e --- /dev/null +++ b/vllm/model_executor/models/apertus.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The Swiss AI Initiative. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate the architectural differences made by +# the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Apertus model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import ApertusConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import XIELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class ApertusMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "xielu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only xIELU is supported for now.") + self.act_fn = XIELU() + + def forward(self, x): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class ApertusAttention(nn.Module): + + def __init__( + self, + config: ApertusConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) + + sliding_window = None + if layer_types := getattr(config, "layer_types", None): + is_sliding = layer_types[layer_idx] == "sliding_attention" + if is_sliding: + sliding_window = config.sliding_window + + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q) + k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb(self, config: ApertusConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "apertus": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class ApertusDecoderLayer(nn.Module): + + def __init__( + self, + config: ApertusConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # Apertus defaults to causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = ApertusAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = ApertusMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_layernorm(hidden_states) + else: + hidden_states, residual = self.attention_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ApertusModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.aux_hidden_state_layers = tuple[int, ...]() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings" + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + return ApertusModel(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 4cf73e2e0ea56ac13dcc9a618e3938b28a89948b..13ed4da0602ad5cb054d887d619dc976aec5039e 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -9,6 +9,7 @@ # activation. from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -243,7 +244,7 @@ class ArceeModel(nn.Module): aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append( hidden_states + diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 4693c9487a8bf1e1cda84bf1b2080c4b84c4e27f..c566611266af7b8a7d32365ebab4579303fa572b 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Snowflake Arctic model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -403,7 +404,7 @@ class ArcticModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index e1368a3f6478ada06be2229dbb86dd995ed9d373..1c7960fa3e0a501d98505dc45cab4502499744d1 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 5cd74bbba4827fcd7a39482811bcbe14e7d1c06e..687c82ded9d0ac6318e43f65e21065a3335ad01b 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -18,7 +18,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token @@ -250,8 +250,7 @@ class AyaVisionMultiModalProcessor( image_processor = hf_processor.image_processor def get_replacement(item_idx: int): - images: ImageProcessorItems = mm_items.get("image", - ImageProcessorItems) + images = mm_items.get_items("image", ImageProcessorItems) image_size: ImageSize = images.get_image_size(item_idx) num_patches = self.info.get_num_patches( image_width=image_size.width, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index ea4122280b03b7cc8051b66c42b6ea5d6d99a3fb..41d9c7f10eb1c5846867ecbd2fabaa5ff3e4bb9a 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -22,6 +22,7 @@ """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -334,7 +335,7 @@ class BaiChuanModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 23cab3509ca8205ca0c9aaf16b300bb0a68235d2..a42640cef9d4456581f730fea9c85d27a1629a08 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only BailingMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -359,8 +360,7 @@ class BailingMoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( hidden_states, position_ids, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index e2cd31af5390aefd3d8053f33f441b4b41e8fc99..a72bbdebe531714102eb2f5a1ee08e026d99ce07 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -345,8 +345,7 @@ class BambaModel(nn.Module): residual = None num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): if isinstance(layer, BambaAttentionDecoderLayer): num_attn += 1 diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6638f06f98261129efa639aca93504715dcde978..8f23439655ed7257457a98d620a9e80e45b800a8 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -24,12 +24,12 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import (SupportsCrossEncoding, SupportsQuant, - default_pooling_type) +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -528,9 +527,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor, def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - ids_mask = torch.ones(input_ids.shape, - dtype=torch.int32, - device=input_ids.device) << TOKEN_TYPE_SHIFT + ids_mask = torch.ones_like(input_ids, + dtype=torch.int32, + device=input_ids.device) << TOKEN_TYPE_SHIFT tokens_mask = ids_mask.bitwise_not() token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index e18b7b7ffabab5eac834e114c2a8794d3340882d..3be7e11d947d5c466e4fce9032c704691bf45775 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -27,13 +27,17 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, - default_pooling_type) -from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler +from .bert import BertPooler +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): @@ -119,14 +123,13 @@ class BertWithRopeAttention(nn.Module): self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") self.out_proj = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, @@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module): class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False): super().__init__() self.vllm_config = vllm_config + self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( @@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant): bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(self.config) if add_pooling_layer else None def forward( self, @@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if not self.add_pooling_layer and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -508,8 +517,8 @@ class GteNewModel(BertWithRope): "attention.o_proj": "attn.out_proj", }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # GteNewModel only gate_up_proj does not have bias. # Hack method learned from vllm/model_executor/models/glm.py @@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope): torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) + + +@default_pooling_type("CLS") +class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.new = GteNewModel(vllm_config=vllm_config, + prefix=prefix, + add_pooling_layer=True) + self.classifier = RowParallelLinear(config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "classifier"), + return_bias=False) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.new(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8e3505f872eb2529047748e92c1183700b608e58..2f2b880bb0e144d84149bcbe872ed6a9a0da9052 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, @@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 9743512a73f28675fc061dae62b89b70e7143c01..ac411b3b66a338e3be185865cd440cf9ba5cc19c 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -20,6 +20,7 @@ """Inference-only BLOOM model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -47,7 +48,7 @@ from vllm.sequence import IntermediateTensors from vllm import _custom_ops as ops from vllm.model_executor.utils import pad_weight, gemm_bank_conf -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .interfaces import SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -293,7 +294,7 @@ class BloomModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -369,7 +370,7 @@ class BloomModel(nn.Module): return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 8d705f40ce8ff3f6c4a9d532211bb92a47d49ad9..28a1a66c232917b78f1690cb1c34b8b5a91d9492 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property +from itertools import islice from typing import Annotated, Any, Literal, Optional, Union import torch @@ -31,7 +32,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -151,7 +152,7 @@ class ChameleonMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -914,7 +915,7 @@ class ChameleonModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 2e7a6f5a1013562a1bdd667ae42e8e497eed7e2b..5b4e002fd37f545280639238e2883fb2483830c8 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -5,6 +5,7 @@ """Inference-only ChatGLM model compatible with THUDM weights.""" import json from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -293,7 +294,7 @@ class GLMTransformer(nn.Module): hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> Union[torch.Tensor, IntermediateTensors]: - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(hidden_states=hidden_states, position_ids=position_ids) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index f17583768f7959be33f82199fad56ade4d38e6c4..179cc2af8eb3fefee8d8f48a91d6b3c509845755 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -10,6 +10,8 @@ import torch from torch import nn from transformers import BatchFeature, PretrainedConfig from transformers.models.cohere2_vision import Cohere2VisionConfig +from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 + get_optimal_tiled_canvas) from transformers.models.cohere2_vision.processing_cohere2_vision import ( Cohere2VisionProcessor) @@ -21,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -150,14 +152,48 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): max_patches = image_processor.max_patches return ImageSize(height=height * max_patches, width=width) - def get_num_patches(self, image_width: int, image_height: int) -> int: + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Cohere2VisionProcessor], + ) -> int: """ Calculate the number of image patches for a given image. Uses the HF processor to determine the actual number of patches. """ - return self.get_hf_processor( - ).image_processor.get_number_of_image_patches(image_height, - image_width, {}) + if processor is None: + processor = self.get_hf_processor() + + image_processor = processor.image_processor + + # The current implementation of get_number_of_image_patches + # is incorrect, so we patch it here. + # TODO: Revert once + # https://github.com/huggingface/transformers/pull/40312 is released. + # return image_processor.get_number_of_image_patches(image_height, + # image_width, {}) + + min_patches = image_processor.min_patches + max_patches = image_processor.max_patches + patch_size = image_processor.size + crop_to_patches = image_processor.crop_to_patches + + if not crop_to_patches: + return 1 + + num_columns, num_rows = get_optimal_tiled_canvas( + (image_height, image_width), + (patch_size["height"], patch_size["width"]), + min_patches, + max_patches, + ) + num_patches = num_columns * num_rows + if num_patches > 1: + num_patches += 1 # Thumbnail image + + return num_patches class Cohere2VisionDummyInputsBuilder( @@ -208,6 +244,8 @@ class Cohere2VisionMultiModalProcessor( # Ensure num_patches is available for proper tensor splitting if "num_patches" not in processed_outputs and ( images := mm_data.get("images")) is not None: + hf_processor = self.info.get_hf_processor(**mm_kwargs) + # Fallback calculation if HF processor didn't provide num_patches parsed_images = self._get_data_parser().parse_mm_data({ "image": @@ -217,8 +255,9 @@ class Cohere2VisionMultiModalProcessor( num_patches = [ self.info.get_num_patches( image_width=parsed_images.get_image_size(i).width, - image_height=parsed_images.get_image_size(i).height) - for i in range(len(parsed_images)) + image_height=parsed_images.get_image_size(i).height, + processor=hf_processor, + ) for i in range(len(parsed_images)) ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -241,29 +280,29 @@ class Cohere2VisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token + img_tokens_per_tile = int(hf_processor.patch_size**2) img_line_break_token = hf_processor.img_line_break_token boi_token = hf_processor.boi_token eoi_token = hf_processor.eoi_token def get_replacement(item_idx: int): - images: ImageProcessorItems = mm_items.get("image", - ImageProcessorItems) + images = mm_items.get_items("image", ImageProcessorItems) image_size: ImageSize = images.get_image_size(item_idx) - num_patches = self.info.get_num_patches(image_size.height, - image_size.width) - img_tokens_per_tile = int(hf_processor.patch_size**2) - single_tile_tokens = image_token * img_tokens_per_tile + \ - img_line_break_token - img_string = f"{boi_token}\ - {single_tile_tokens * num_patches}\ - {eoi_token}" + num_patches = self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + patch_tokens = (image_token * img_tokens_per_tile + + img_line_break_token) + repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" - return PromptUpdateDetails.select_text(img_string, image_token) + return PromptUpdateDetails.select_text(repl, image_token) return [ PromptReplacement( @@ -311,7 +350,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=["Cohere2ForCausalLM"]) + architectures=config.text_config.architectures) @property def dtype(self): diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 4dd84b8f8fdd56248fe872b40b3eabd97e01d50f..7f87e31abdcd3c081375a1901efabaee6473cdc3 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -23,6 +23,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -322,7 +323,7 @@ class CohereModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 882df7e8162c56e7bcd399903c1fa9f8c0e0ddd5..377b7bf26a07a247ba7b1aeb6bd848d5c3ca689c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,6 +4,7 @@ from copy import deepcopy from typing import TYPE_CHECKING import vllm.envs as envs +from vllm.config.compilation import CUDAGraphMode from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv @@ -275,6 +276,43 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): "%d for performance.", 1024) +class MambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Enable FULL_AND_PIECEWISE cuda graph mode by default (required + to get good performance for mamba layers in V1). + + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + + # TODO(tdoublep): remove once prefix caching is enabled + cache_config.enable_prefix_caching = False + logger.info("Hybrid or mamba-based model detected: disabling prefix " + "caching since it is not yet supported.") + + # TODO(tdoublep): remove as full cuda graph support is added + FCG_NOT_SUPPORTED_MODELS = [ + "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + ] + + if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS + and compilation_config.cudagraph_mode is None): + logger.info( + "Hybrid or mamba-based model detected: setting cudagraph mode " + "to FULL_AND_PIECEWISE in order to optimize performance.") + compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -293,6 +331,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): if not envs.VLLM_USE_V1: return + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -365,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, + "GteNewForSequenceClassification": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, @@ -374,4 +416,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, "GptOssForCausalLM": GptOssForCausalLMConfig, + "MambaForCausalLM": MambaModelConfig, + "Mamba2ForCausalLM": MambaModelConfig, + "FalconMambaForCausalLM": MambaModelConfig, } diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index e74d90e0b1d7dd05811dd661b199009f7ea1a495..519cd522213b29bb9e6904eb801c7cab6c49cc25 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -359,7 +360,7 @@ class DbrxModel(nn.Module): else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index ca1962161270631274049746fd282da5ed4be90e..1407e0c198968e46458401de48529f4f26b79822 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only Deepseek model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -51,7 +52,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -387,7 +388,7 @@ class DeepseekModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -448,7 +449,11 @@ class DeepseekModel(nn.Module): return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP): +class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -493,4 +498,4 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9c83cf610001a2863cdac024c9334f3aacb91b --- /dev/null +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM) +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + ) -> None: + super().__init__() + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = self.config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + DeepseekV2DecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ) for i in range(self.config.num_hidden_layers) + ]) + + self.fc = nn.Linear( + self.config.model.hidden_size * 2, + self.config.model.hidden_size, + bias=False, + ) + + self.enorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + + inputs = torch.cat( + [self.enorm(input_embeds), + self.hnorm(hidden_states)], dim=-1) + hidden_states = self.fc(inputs) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + quant_config = vllm_config.quant_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num) + + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0a90c6864372726581da93fb229265e475fb6300..27fd7043876712168a48d4fba47f8fd8309058cf 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -29,6 +29,7 @@ import vllm.envs as envs import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsPP +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -132,17 +133,17 @@ class DeepseekV2MoE(nn.Module): prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float32)) else: self.gate.e_score_correction_bias = None # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) @@ -166,6 +167,7 @@ class DeepseekV2MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, @@ -730,7 +732,7 @@ class DeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -743,7 +745,8 @@ class DeepseekV2Model(nn.Module): return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, + SupportsLoRA): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index e0acca75d9dd60f16c06c5eea7555f521fad7653..5eab02b17151c8a09a22efa59df4fc0d6a970d91 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -21,11 +21,12 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -252,7 +253,7 @@ class DeepseekVL2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -289,9 +290,8 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -302,7 +302,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return super()._cached_apply_hf_processor( @@ -310,7 +310,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) @@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if isinstance(module, nn.Linear): parent, attr_name = self._get_parent_and_attr(vit, name) if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": - new_linear = replace_linear_class(module, "colwise", - quant_config) + new_linear = replace_linear_class(module, + "colwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": - new_linear = replace_linear_class(module, "rowwise", - quant_config) + new_linear = replace_linear_class(module, + "rowwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) return vit diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 0000000000000000000000000000000000000000..c00db52371b68e5d3b14681d57bacb5700052a50 --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, Optional, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + h, w = self.config.encoder.image_size + return DonutImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": h, + "w": w, + }) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 5f410c0ae5fb08d3267412722657faaf912ef7d8..a5477af8694b441f6d7e50af00e08324dbc26700 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -25,6 +25,7 @@ # limitations under the License. """Inference-only dots1 model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -136,6 +137,7 @@ class Dots1MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias) if config.n_shared_experts is not None: @@ -391,7 +393,7 @@ class Dots1Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index e0844f86301171ccbe3fb6f76dccf5f0f1dc0bc2..5c82d0e95b6c6eb85a751bd58959092064563689 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only ErineMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -419,8 +420,7 @@ class Ernie4_5_MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..d880fc434e20fd17f0210a89ff3bb197b41a35bc --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors + +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +_MAX_FRAMES_PER_VIDEO = 16 + +# === Vision Transformer === # + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class Ernie4_5_VisionAttention(nn.Module): + """VisionAttention using VLLM framework APIs""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now." + ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.is_flash_attn_backend: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Ernie4_5_VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class Ernie4_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Ernie4_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.mlp = Ernie4_5_VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Ernie4_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", + ) -> None: + + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_channels * patch_size * patch_size, + embed_dim, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.to(target_dtype) + hidden_states = self.proj(hidden_states) + + return hidden_states + + +class Ernie4_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / theta**( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +class Ernie4_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + super().__init__() + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Ernie4_5_VisionPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + prefix=f"{prefix}.patch_embed", + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Ernie4_5_VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + assert (hidden_size == embed_dim + ), "vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + num_pad=0) -> torch.Tensor: + + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # add batch size + if hidden_states.ndim == 2: + hidden_states = hidden_states.unsqueeze(dim=1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + for i, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + final_output = self.ln(hidden_states) + + if final_output.ndim == 3: + final_output = final_output.squeeze(dim=1) + + return final_output + + def load_weights(self, weights) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# === Vision Inputs === # + + +class Ernie4_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + + +class Ernie4_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs + +# === Vision Processor === # + + +def round_by_factor(number: Union[int, float], factor: int) -> int: + return round(number / factor) * factor + + +def ceil_by_factor(number: Union[int, float], factor: int) -> int: + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: Union[int, float], factor: int) -> int: + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +class VariableResolutionResamplerModel(nn.Module): + + def __init__(self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "") -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size) + # compress 3d conv(video) to 1d + self.temporal_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size * self.temporal_conv_size) + + self.spatial_linear1 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear1", + ) + + self.spatial_gelu = nn.GELU() + + self.spatial_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear2", + ) + + self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + if self.use_temporal_conv: + self.temporal_linear1 = ColumnParallelLinear( + self.temporal_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear1", + ) + + self.temporal_gelu = nn.GELU() + + self.temporal_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear2", + ) + + self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + self.mlp = ColumnParallelLinear( + self.spatial_dim, + self.out_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.mlp", + ) + + self.after_norm = RMSNorm(hidden_size=out_dim, + eps=getattr(config, 'rms_norm_eps', 1e-6)) + + def spatial_conv_reshape(self, x, spatial_conv_size): + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size**2)]) + return x + + def forward(self, x, grid_thw): + + def fwd_spatial(x): + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x, _ = self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** + 2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( + self.spatial_conv_size**2) + batch_offset = np.empty(tokens_per_img_or_vid.size, + dtype=tokens_per_img_or_vid.dtype) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + slice_offsets = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, + axis=-1)).to(x.device) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(1 if temporoal_size > 1 else 0, + temporoal_size, 2): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets2 = torch.tensor( + np.concatenate(slice_offsets2, axis=-1)).to(x.device) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Any], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Any], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Any], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. + if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ernie4_5VLMultiModalProcessor( + BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor(image_processor.image_mean, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + image_std_tensor = torch.tensor(image_processor.image_std, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + rescale_factor = torch.tensor(image_processor.rescale_factor, + dtype=torch.float32) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = (image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_std_tensor = (image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = image_std_tensor.contiguous() + + pixel_values = (rescale_factor * pixel_values.to(torch.float32) - + image_mean_tensor) / image_std_tensor + pixel_values = pixel_values.to(hf_config.torch_dtype) + return pixel_values + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # when the prompt is not empty but the multimodal data is empty, + # directly invoke the tokenizer. + if "images" not in mm_data and "videos" not in mm_data and prompt != "": + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), + tensor_type="pt") + return tokenizer_output + + if "images" not in mm_data: + mm_data["images"] = [] + if "videos" not in mm_data: + mm_data["videos"] = [] + processor_output = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=[prompt], + images=mm_data["images"], + videos=mm_data["videos"]), + dict(**mm_kwargs, **tok_kwargs), + ) + + # Divide the processor_output into two modalities: image and video. + if processor_output is not None: + pixel_values = processor_output['images'] + if pixel_values is not None: + processor_output['images'] = self._pixel_values_norm( + pixel_values, mm_kwargs) + for key in list(processor_output.keys()): + if processor_output[key] is None: + del processor_output[key] + continue + if key == "grid_thw": + grid_thw = processor_output['grid_thw'] + pixel_values_all = processor_output['images'] + # Identify elements where the first + # dimension is greater than 1 and + # treat them as the video modality + mask = grid_thw[:, 0] > 1 + processor_output["video_grid_thw"] = grid_thw[mask] + processor_output["image_grid_thw"] = grid_thw[~mask] + image_patch_num = processor_output["image_grid_thw"].prod( + dim=1).sum() + processor_output[ + 'pixel_values'] = pixel_values_all[:image_patch_num] + processor_output['pixel_values_videos'] = pixel_values_all[ + image_patch_num:] + del processor_output['images'] + + return processor_output + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + before_placeholder = { + "image": "<|image@placeholder|>", + "video": "<|video@placeholder|>" + } + + after_placeholder = { + # image and video have same placeholder + "image": "<|IMAGE_PLACEHOLDER|>", + "video": "<|IMAGE_PLACEHOLDER|>" + } + + merge_length = hf_processor.spatial_conv_size**2 + + def get_replacement_ernie45vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + if modality == "video": + num_tokens = int(grid_thw.prod( + )) // hf_processor.temporal_conv_size // merge_length + else: + num_tokens = int(grid_thw.prod()) // merge_length + return after_placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, + modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Ernie4_5_VLDummyInputsBuilder( + BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += (f"Picture {i+1}:" + "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") + + for i in range(num_videos): + prompt += (f"Video {i+1}:" + "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos) + } + + +@MULTIMODAL_REGISTRY.register_processor( + Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder) +class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. + # language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }, + # resampler_weight_mappings + orig_to_new_substr={ + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.3.": "temporal_norm.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + if modality.startswith("video"): + return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_model = Ernie4_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = Ernie4_5_VLMoeForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "resampler_model")) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def _vision_forward( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering," + "which is not divisible by 3.") + grid_thw = grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = ( + input_ids == self.config.im_patch_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Ernie4_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Ernie4_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, + image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type( + self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, + grid_thw=grid_thw) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype) + video_features = self._vision_forward(pixel_values=pixel_values_videos, + grid_thw=grid_thw) + video_embeds = self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = (grid_thw.prod(-1) // + self.config.temporal_conv_size) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is None: + return inputs_embeds + + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, + multimodal_embeddings, + [self.config.im_patch_id]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + if self.visual_token_mask is not None: + + if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: + padding_len = inputs_embeds.shape[ + 0] - self.visual_token_mask.shape[0] + # right pad False + pad = torch.zeros( + (padding_len, self.visual_token_mask.shape[1]), + dtype=self.visual_token_mask.dtype, + device=self.visual_token_mask.device) + self.visual_token_mask = torch.cat( + [self.visual_token_mask, pad], dim=0) + + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model( + **forward_kwargs, + **kwargs, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..780974c3b758e8187d71b916ed7f42a0be716130 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +# from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .ernie45_moe import Ernie4_5_MoeMLP +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): + pass + + +class Ernie4_5_VLMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + freq_allocation: int = 20, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype=torch.get_default_dtype(), + mrope_section=[h_rope, w_rope, t_rope]) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Ernie4_5_VLMoeMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) + > 0) + self.hidden_size = config.hidden_size + + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + + if self.tp_size > max_moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {moe_num_experts}.") + + moe_layer_start_index = config.moe_layer_start_index + text_moe_layer_start_index = moe_layer_start_index[0] + vision_moe_layer_start_index = moe_layer_start_index[1] + moe_layer_end_index = config.moe_layer_end_index + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + text_moe_layer_end_index = moe_layer_end_index[0] + vision_moe_layer_end_index = moe_layer_end_index[1] + + assert config.moe_num_experts[0] == config.moe_num_experts[1] + self.e_score_correction_bias = nn.Parameter( + torch.empty(2, config.moe_num_experts[0])) + + assert text_moe_layer_start_index <= text_moe_layer_end_index + + if layer_idx >= text_moe_layer_start_index and \ + layer_idx <= text_moe_layer_end_index: + self.text_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[0], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate") + + self.text_experts = FusedMoE( + num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts") + else: + self.text_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + assert vision_moe_layer_start_index <= vision_moe_layer_end_index + if layer_idx >= vision_moe_layer_start_index and \ + layer_idx <= vision_moe_layer_end_index: + self.vision_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[1], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.vision_experts_gate") + + self.vision_experts = FusedMoE( + num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.vision_experts") + else: + self.vision_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + if self.has_shared_experts: + intermediate_size = (config.moe_intermediate_size[0] * + config.moe_num_shared_experts) + self.shared_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=self.text_experts. + must_reduce_shared_expert_outputs()) + + def forward( + self, + hidden_states: torch.Tensor, + visual_token_mask: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.has_shared_experts: + shared_output = self.shared_experts(hidden_states) + + if visual_token_mask is not None and visual_token_mask.any(): + # assert visual_token_mask.shape[0] != hidden_states.shape[0] + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + final_hidden_states = torch.zeros_like(hidden_states) + + text_hidden_states = hidden_states[text_token_mask].reshape( + -1, self.hidden_size) + vision_hidden_states = hidden_states[visual_token_mask].reshape( + -1, self.hidden_size) + + text_router_logits, _ = self.text_experts_gate(text_hidden_states) + final_hidden_states[text_token_mask] = self.text_experts( + hidden_states=text_hidden_states, + router_logits=text_router_logits).flatten() + + vision_router_logits, _ = self.vision_experts_gate( + vision_hidden_states) + final_hidden_states[visual_token_mask] = self.vision_experts( + hidden_states=vision_hidden_states, + router_logits=vision_router_logits).flatten() + else: + # text modal input processing directly + text_router_logits, _ = self.text_experts_gate(hidden_states) + + final_hidden_states = self.text_experts( + hidden_states=hidden_states, router_logits=text_router_logits) + + if self.has_shared_experts and \ + shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = ( + self.text_experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_VLMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + freq_allocation = getattr(config, "freq_allocation", 20) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + + self.self_attn = Ernie4_5_VLMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, 'head_dim', None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + freq_allocation=freq_allocation, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'use_bias', False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_layer_start_index = config.moe_layer_start_index + min_moe_layer_start_index = min(moe_layer_start_index) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + max_moe_layer_end_index = max(moe_layer_end_index) + assert min_moe_layer_start_index <= max_moe_layer_end_index + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", max_moe_num_experts > 0) + + if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index): + self.mlp = Ernie4_5_VLMoeMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, + **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and vision experts, +# enabling torch.compile will cause errors. +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# "visual_token_mask": 0, +# }) +class Ernie4_5_VLMoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.im_patch_id = config.im_patch_id + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_VLMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual, + visual_token_mask, **kwargs) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +# only used as text backbone for ernie4.5-vl +class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=max(self.config.moe_num_experts)) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + loaded_params.add("lm_head.weight") + continue + # MTP will be supported soon. + if "mtp" in name or \ + "vision_model" in name or \ + "resampler_model" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Distinguish between vision experts and text experts + if "mlp.experts" in name: + moe_offset = int(name.split(".")[-3]) + vision_expert_start_idx = self.config.moe_num_experts[0] + is_text_expert = \ + moe_offset <= vision_expert_start_idx - 1 + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace( + f".experts.{moe_offset}", + f".vision_experts.{moe_offset-vision_expert_start_idx}" + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + # Distinguish between vision experts and text experts + moe_offset = int(name.split(".")[-3]) + is_text_expert = \ + moe_offset <= self.config.moe_num_experts[0] - 1 + + name = name.replace(weight_name, param_name) + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(".experts.", ".vision_experts.") + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Distinguish between vision expert gate + # and text expert gate + if name.endswith("mlp.gate.weight"): + name = name.replace("gate.weight", + "text_experts_gate.weight") + loaded_weight = loaded_weight.T + elif name.endswith("mlp.gate.weight_1"): + name = name.replace("gate.weight_1", + "vision_experts_gate.weight") + loaded_weight = loaded_weight.T + + if "e_score_correction_bias" in name: + name = name.replace(".moe_statics.", ".") + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py new file mode 100644 index 0000000000000000000000000000000000000000..90a1267b28f0a6169a8aff341d4d3a010140c642 --- /dev/null +++ b/vllm/model_executor/models/ernie_mtp.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie-MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .llama import LlamaDecoderLayer +from .utils import is_pp_missing_parameter, maybe_prefix + + +class ErnieMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.mtp_emb_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, + prefix) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + + inputs_embeds = self.mtp_emb_norm(inputs_embeds) + previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) + + hidden_states = self.mtp_linear_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + + return hidden_states + + +class ErnieMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + ErnieMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + inputs_embeds, + positions, + previous_hidden_states, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + lm_head: ParallelLMHead, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + +class ErnieMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.sampler = get_sampler() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + assert spec_step_idx == 0, "ernie_mtp only support predict one token" + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, self.lm_head, + sampling_metadata, spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + continue + if "rotary_emb.inv_freq" in name: + continue + if "mtp" in name: + name = self._rewrite_spec_layer_name(self.config, name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if "mtp" not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if "mtp_" not in name and ("embed_tokens" not in name + and "lm_head" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, config: PretrainedConfig, + name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + """ + spec_layer_weight_names = [ + "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", + "mtp_linear_proj" + ] + layer_idx = config.num_hidden_layers + for weight_name in spec_layer_weight_names: + if weight_name in name: + name = name.replace( + f"model.{weight_name}.0.", + f"model.layers.{layer_idx}.{weight_name}.") + return name + name = name.replace("model.mtp_block.0.", + f"model.layers.{layer_idx}.mtp_block.") + return name diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 8052b6bb82348c06b27ada3388123dc16035d3ad..942db0143a457733377d21dc3785e7926989661f 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -26,6 +26,7 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -371,7 +372,7 @@ class ExaoneModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 827e9014184b5984c7a8ff21b5e112ed079d9f6b..971fcbd2aa2754cc2f182a8f83f6eef118dcc3f1 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -22,6 +22,7 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -354,7 +355,7 @@ class Exaone4Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 968bc7d6721db430b69d083f4e86cee6309c7ee6..327c9b5ccc93ac52bee19588054c8f34820bfa82 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -22,6 +22,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import os @@ -412,7 +413,7 @@ class FalconModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 56e456c2f1f2a0b96ffaa6b441dc7f2cae8e3353..d0881231fb1e77de9c131e88c42d11daf22acfc9 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -21,7 +21,7 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, @@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module): encoder_hidden_states = None - if inputs_embeds is not None or encoder_input_ids.numel() > 0: + if ((inputs_embeds is not None and inputs_embeds.numel() > 0) + or encoder_input_ids.numel() > 0): # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, @@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): self.lm_head = BartParallelLMHead(self.vocab_size, config.d_model, embed_scale=embed_scale) + if self.config.tie_word_embeddings: + self.lm_head.tie_weights(self.model.shared) self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) @@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): else: if "final_logits_bias" in name: continue - if self.config.tie_word_embeddings and "embed_tokens" in name: + if self.config.tie_word_embeddings and ("embed_tokens" in name + or "lm_head" in name): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", @@ -860,7 +864,7 @@ class Florence2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index b61e0361fe8c3ee440b79cd4b55d3778ca8be38c..90af859ab92ec8c215e1cffab55f694dfecfd73f 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -32,7 +32,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -226,7 +226,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 59c3102add4c7b9cb58e1a2c40692918af8685b1..12eb27503870c20dc09d462975742f446f48035e 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -18,6 +18,7 @@ """Inference-only Gemma model compatible with HuggingFace weights.""" from collections.abc import Iterable from functools import cache +from itertools import islice from typing import Optional, Union import torch @@ -308,7 +309,7 @@ class GemmaModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 8cfe92c64540f2d45c9e9fd78e84788fc8300ebb..0bdb6c6bf7ae990aad8950680c6f8c09636478ab 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -292,7 +293,7 @@ class Gemma2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index b762be3c5292511f9fb57df9c53e2a62bca4d693..410c715d5241b717d4699322b21522f42ae02524 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -398,7 +399,7 @@ class Gemma3Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 9871b11b37991f87dad300752dd636663897519d..f3dc7dde46bdf568fa97d7bf9299fd1e5957cecc 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -17,16 +17,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -311,7 +312,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.boi_token @@ -337,14 +338,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -373,13 +370,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -404,8 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index a0c3bb50070b3a45f2fac72992a5a86472ca782c..d59dde1560aea26f0747db7aaf3cf74d7a9b1ca6 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -24,16 +24,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, MultiModalDataParser) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -209,7 +210,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -254,14 +255,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -290,13 +287,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -321,8 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 88c53c8363275ee6d6808c4dfa77d71023ce30a9..662728e6b13934f19c8883307294d2c685e27cac 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -59,7 +59,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision +from .qwen2_vl import (_create_qwen2vl_field_factory, + apply_rotary_pos_emb_vision) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -126,7 +127,7 @@ class Glm4vVideoPixelInputs(TensorSchema): - ctpp: Number of channels * temporal_patch_size * patch_size * patch_size - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ type: Literal["pixel_values_videos"] = "pixel_values_videos" @@ -141,7 +142,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): - p: Number of video patches across all frames - h: Hidden size (must match language model backbone) - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ type: Literal["video_embeds"] = "video_embeds" @@ -234,7 +235,8 @@ class Glm4vVisionAttention(nn.Module): total_num_kv_heads=num_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv", + # Change qkv prefix to align with GLM-4.5V-FP8 quantization config + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", ) self.proj = RowParallelLinear( input_size=projection_size, @@ -1152,13 +1154,15 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1175,14 +1179,16 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): merge_length = image_processor.merge_size**2 def get_image_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["image_grid_thw"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [hf_processor.image_token_id] * num_tokens def get_video_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index aff491f9596c358369ea2bd714599d2b200af938..06ed453ec29f9325b399d686f4bb9ef388167127 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -24,6 +24,7 @@ """Inference-only GLM-4.5 model compatible with HuggingFace weights.""" import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -131,10 +132,10 @@ class Glm4MoE(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) @@ -158,6 +159,7 @@ class Glm4MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func="sigmoid", + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) @@ -440,8 +442,7 @@ class Glm4MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 1751fccd08b0623161379d46142db87331d1baf2..bf33575859aeae620228aac149007c65c9092068 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -503,7 +503,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 98d76337395b9f9dda2050d255978c3f68dd1f9b..4446b5ab181c1c99d390ea89a35d8fae14e38241 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -20,6 +20,7 @@ # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -228,7 +229,7 @@ class GPT2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 036ded530f97df8e6a241508ac8bbffb2f580779..d5c2604145eed5be9000c65b60a6c1fa70115729 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -21,6 +21,7 @@ # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -246,7 +247,7 @@ class GPTBigCodeModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index bd162a5e57bc18ab84cde42cec832b5def14b465..584c7f5d8a2d162d86a4ff2cfe4c1662f4a4e6f6 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -19,6 +19,7 @@ # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -223,7 +224,7 @@ class GPTJModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -336,4 +337,4 @@ class GPTJForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 0f46d75790fa3ea68833cbec68e3bfe2d37c4740..5ce15116d570ae0e0d556f76053eecfdc77240d8 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -19,6 +19,7 @@ # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import os @@ -242,7 +243,7 @@ class GPTNeoXModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 2f5d9ddd9054f33a3d6039a30a6df48ef3ad6600..e0b4df77287571db6af6bab544adc0ec40cfc44b 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -11,7 +11,8 @@ from transformers import GptOssConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -27,7 +28,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .utils import extract_layer_index, maybe_prefix +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class OAIAttention(nn.Module): @@ -74,8 +79,6 @@ class OAIAttention(nn.Module): dtype=torch.bfloat16, requires_grad=False)) - self.norm = RMSNorm(config.hidden_size, eps=1e-5) - self.q_size = self.num_attention_heads * self.head_dim // tp_size self.kv_size = self.num_key_value_heads * self.head_dim // tp_size self.scaling = self.head_dim**-0.5 @@ -118,16 +121,13 @@ class OAIAttention(nn.Module): def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - t = self.norm(hidden_states) - - qkv, _ = self.qkv(t) + qkv, _ = self.qkv(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) v = v.contiguous() attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) - - return output + hidden_states + return output class MLPBlock(torch.nn.Module): @@ -144,7 +144,6 @@ class MLPBlock(torch.nn.Module): self.num_experts = config.num_local_experts self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.norm = RMSNorm(config.hidden_size, eps=1e-5) self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts, dtype=torch.bfloat16) @@ -162,10 +161,9 @@ class MLPBlock(torch.nn.Module): activation="swigluoai") def forward(self, x: torch.Tensor) -> torch.Tensor: - t = self.norm(x) - g = self.router(t) - t = self.experts(hidden_states=t, router_logits=g) - return x + t + g = self.router(x) + x = self.experts(hidden_states=x, router_logits=g) + return x class TransformerBlock(torch.nn.Module): @@ -173,22 +171,41 @@ class TransformerBlock(torch.nn.Module): def __init__( self, config: GptOssConfig, + cache_config: CacheConfig, quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.attn = OAIAttention(config, + prefix=f"{prefix}.attn", + cache_config=cache_config) self.mlp = MLPBlock(config, self.layer_idx, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) - def forward(self, hidden_states: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - attn_output = self.attn(hidden_states, positions) - output = self.mlp(attn_output) - return output + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.attn(hidden_states, positions) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + output = self.mlp(hidden_states) + return output, residual @support_torch_compile @@ -202,87 +219,82 @@ class GptOssModel(nn.Module): ): super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, ) - self.layers = torch.nn.ModuleList([ - TransformerBlock( + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: TransformerBlock( self.config, + cache_config=self.cache_config, quant_config=self.quant_config, - prefix=maybe_prefix(prefix, f"block.{layer_idx}"), - ) for layer_idx in range(self.config.num_hidden_layers) - ]) + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) - def forward(self, input_ids: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - x = self.embedding(input_ids) - for layer in self.layers: - x = layer(x, positions) - x = self.norm(x) - return x - - -class GptOssForCausalLM(nn.Module): + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) - def __init__( + def forward( self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__() - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config.hf_config - self.model = GptOssModel( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - ) - self.lm_head = ParallelLMHead( - self.model_config.vocab_size, - self.model_config.hidden_size, - ) - self.logits_processor = LogitsProcessor(self.model_config.vocab_size) - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: - assert intermediate_tensors is None - assert inputs_embeds is None - return self.model(input_ids, positions) + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + x = inputs_embeds + else: + x = self.get_input_embeddings(input_ids) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + residual = None + else: + assert intermediate_tensors is not None + x = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + x, residual = layer(x, positions, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": x, + "residual": residual + }) + x, _ = self.norm(x, residual) + return x def _load_weights_mxfp4( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + mxfp4_block = 32 + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size + + intermediate_size = self.config.intermediate_size intermediate_size_block = intermediate_size // mxfp4_block per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) @@ -294,33 +306,16 @@ class GptOssForCausalLM(nn.Module): tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - for name, weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # FIXME(woosuk): Remove this after testing. weight = weight.cuda() - if "gate_up_proj_blocks" in name: - # Handle MLP gate and up projection weights - new_name = name.replace("gate_up_proj_blocks", "w13_weight") - - # flat weight from (E, 2 * N, block_size, entry_per_block) - # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: @@ -328,43 +323,44 @@ class GptOssForCausalLM(nn.Module): 2 * tp_rank_start:2 * tp_rank_end, ...] - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, narrow_weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_blocks" in name: + loaded_params.add(name) + continue + elif ".w2_weight_scale" in name: # Handle MLP down projection weights - new_name = name.replace("down_proj_blocks", "w2_weight") - # same flatten here, but since 2 mx4 value are packed in 1 - # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, narrow_weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) + loaded_params.add(name) + continue + elif ".w13_weight" in name: + # Handle MLP gate and up projection weights + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() - elif "gate_up_proj_scales" in name: - # Handle MLP gate and up projection weights scale - new_name = name.replace("gate_up_proj_scales", - "w13_weight_scale") + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: @@ -372,39 +368,40 @@ class GptOssForCausalLM(nn.Module): 2 * tp_rank_start:2 * tp_rank_end, ...] - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, narrow_weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_scales" in name: + loaded_params.add(name) + continue + elif ".w2_weight" in name: # Handle MLP down projection weights - new_name = name.replace("down_proj_scales", "w2_weight_scale") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, narrow_weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) - elif "gate_up_proj_bias" in name: + loaded_params.add(name) + continue + elif ".w13_bias" in name: # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - # Extract gate and up projection bias parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] @@ -412,20 +409,19 @@ class GptOssForCausalLM(nn.Module): narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end] - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, narrow_weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: + loaded_params.add(name) + continue + elif ".w2_bias" in name: # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - param = params_dict[new_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if use_ep: @@ -436,87 +432,73 @@ class GptOssForCausalLM(nn.Module): weight.zero_() weight_loader(param, weight, - weight_name=new_name, + weight_name=name, shard_id=None, expert_id=None) - loaded_params.add(new_name) + loaded_params.add(name) + continue elif "sinks" in name: # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") param = params_dict[name] narrow_weight = weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break else: # Handle all other weights with potential renaming - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: + if name not in params_dict: continue - param = params_dict[renamed_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) - loaded_params.add(renamed_name) - + loaded_params.add(name) return loaded_params def _load_weights_other( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + use_ep = self.parallel_config.enable_expert_parallel + tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size + intermediate_size = self.config.intermediate_size per_rank_intermediate_size = cdiv(intermediate_size, tp_size) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - for name, weight in weights: - if ".experts.gate_up_proj" in name and "bias" not in name: - # Handle MLP gate and up projection weights - new_name = name.replace(".experts.gate_up_proj", - ".experts.w13_weight") + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ".w13_weight" in name: + # Handle MLP gate and up projection weights # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: @@ -524,30 +506,25 @@ class GptOssForCausalLM(nn.Module): 2 * tp_rank_start:2 * tp_rank_end] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] + param = params_dict[name] param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif ".experts.down_proj" in name and "bias" not in name: + loaded_params.add(name) + continue + elif ".w2_weight" in name: # Handle MLP down projection weights - new_name = name.replace(".experts.down_proj", - ".experts.w2_weight") - if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] + param = params_dict[name] param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "gate_up_proj_bias" in name: + loaded_params.add(name) + continue + elif ".w13_bias" in name: # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - # Extract gate and up projection bias parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] @@ -555,60 +532,156 @@ class GptOssForCausalLM(nn.Module): narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end] - param = params_dict[new_name] - + param = params_dict[name] param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: + loaded_params.add(name) + continue + elif ".w2_bias" in name: # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] else: # (only load on rank 0 to avoid duplication) if tp_rank != 0: weight.zero_() - param = params_dict[new_name] + param = params_dict[name] param.copy_(weight) - loaded_params.add(new_name) + loaded_params.add(name) + continue elif "sinks" in name: # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") param = params_dict[name] narrow_weight = weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break else: # Handle all other weights with potential renaming - - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: + if name not in params_dict: continue - param = params_dict[renamed_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) - loaded_params.add(renamed_name) - + loaded_params.add(name) return loaded_params def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - quant_method = (self.model_config.quantization_config['quant_method'] - if hasattr(self.model_config, "quantization_config") - else None) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + quant_method = (self.config.quantization_config['quant_method'] if + hasattr(self.config, "quantization_config") else None) if quant_method == "mxfp4": - return self._load_weights_mxfp4(weights) + return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) else: - return self._load_weights_other(weights) + return self._load_weights_other(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + + +class GptOssForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + }, + orig_to_new_suffix={ + ".embed_tokens.weight": ".embedding.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + ) + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 507a9206c42814c53b53657663c3808bbd993bef..f8ba0229210a9431252506300affe1853555d9dc 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -316,7 +317,7 @@ class GraniteModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index c9e3b74e7c3c4718313562c0f61982973afe27dd..221023f1fb657dbadcec024251d51b29d3776e59 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -40,7 +40,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -118,7 +118,7 @@ class GraniteSpeechMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration( raise ValueError("Only audio modality is supported") - def __init__(self, *, vllm_config: VllmConfig, prefix: str): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 7d31854dce8d858452ff792c56ae9eb155358cae..07ad75bcf16653b6a220ca2d247f009cec9aed8e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only GraniteMoe model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -303,7 +304,7 @@ class GraniteMoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index f451e65338b78d9fd4da5d45ae8e2c72e1f86035..79c6d8146ba9c390b6da9807a26f3dd22e9748c2 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -397,8 +397,7 @@ class GraniteMoeHybridModel(nn.Module): residual = intermediate_tensors["residual"] num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 1e2e8544179c7404e3132e8e11cc382c6e59d8a4..0b568a4b22685cb045c3f1e6beeb778c1b5f419e 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -6,6 +6,7 @@ The architecture is the same as granitemoe but with the addition of shared experts. """ from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -200,8 +201,7 @@ class GraniteMoeSharedModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 9e7490e3c4f07dddbd229db757438ad9893185ad..a7b324f0a5b4c3eeb9b46fd8b629036a54c20eb4 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -15,12 +15,12 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, build_output, get_prompt_lens, get_prompt_token_ids) from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsV0Only +from .interfaces_base import default_pooling_type logger = init_logger(__name__) @@ -215,7 +215,8 @@ class GritLMPooler(Pooler): return build_output(pooled_data) -class GritLM(LlamaForCausalLM, SupportsV0Only): +@default_pooling_type("MEAN") +class GritLM(LlamaForCausalLM): """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. The class inherits from LlamaForCausalLM and provides a custom pooling @@ -241,7 +242,6 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): prefix: str = "", **kwargs, ) -> None: - # Use full attention for pooling (this is why V1 is not supported yet) if vllm_config.model_config.runner_type == "pooling": hf_config = vllm_config.model_config.hf_config hf_config.is_causal = False diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 3659249cd8bd628f19b6189a5c9e0acf9aca8376..a591134383371a364a4a3bdbea82d5824edca830 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only Grok1 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -347,8 +348,7 @@ class Grok1Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index c3e4f81597adb6ea6a09e40e5e4c8a4c7c6e0c6a..306775af68065429eee51e3928612013f6f5b8d8 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,11 +17,12 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import (MultiModalProcessingInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel @@ -425,18 +426,19 @@ class H2OVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -477,9 +479,8 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -490,7 +491,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return super()._cached_apply_hf_processor( @@ -498,7 +499,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index e5c94c7f3a7068fa15a9c02ce7e2b137182f3674..53f0585541b1c8809c95d7bccb128428794d7f9f 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -53,6 +54,21 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" +# Based on combine_frames_into_images in +# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py +def get_num_combined_frames( + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), +) -> int: + max_num_grids = max_grid_shape[0] * max_grid_shape[1] + + # Calculate the number of canvases needed. + num_canvases = num_frames // max_num_grids + leftover_frames = num_frames % max_num_grids + + return num_canvases + (leftover_frames > 0) + + class HCXVisionMultimodalPixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values_images: list[torch.Tensor] @@ -172,23 +188,20 @@ class HCXVisionMultiModalProcessor( def replace_multimodal_token( token_ids: torch.Tensor, target_token: int, - repeats: list, + repeats: list[int], ): - output = list() + output = list[int]() _repeats_idx = 0 for token_id in token_ids: if token_id == target_token: - output += [ - token_id.item(), - ] * repeats[_repeats_idx] + output += [token_id.item()] * repeats[_repeats_idx] _repeats_idx += 1 else: - output += [ - token_id.item(), - ] + output += [token_id.item()] + return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", list())): + for video_idx, video_arr in enumerate(mm_data.get("videos", [])): if video_arr.dtype == np.uint8: continue mm_data["videos"][video_idx] = video_arr.astype(np.uint8) @@ -205,88 +218,68 @@ class HCXVisionMultiModalProcessor( if len(mm_data) > 0: # batchify input as a single item images = mm_data.get("images", None) - num_images = 0 - if images is not None: - num_images = len(images) - images = [ - images, - ] # batchify - - videos = mm_data.get("videos", - None) # list of video in single conversation - num_videos = 0 - if videos is not None: - num_videos = len(videos) - videos = [ - videos, - ] # batchify + batched_images = None if images is None else [images] + + # list of video in single conversation + videos = mm_data.get("videos", None) + batched_videos = None if videos is None else [videos] _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=images, - videos=videos, + images=batched_images, + videos=batched_videos, ), ) # mm-only for k, v in _processed_outputs.items(): - if len(v) < 1: - continue - elif k.endswith("_images"): - # list of list of 4D tensor -> list of 4D tensor + if isinstance(v, list) and len(v) > 0: + assert len(v) == 1 _processed_outputs[k] = v[0] - elif k.endswith("_videos"): - # list of list of 4D tensor -> list of 4D tensor - v = v[0] - if k == "pixel_values_videos": - v = torch.cat(v, dim=0) - _c, _w, _h = v.shape[-3:] - v = v.reshape(num_videos, -1, _c, _w, _h) - v = list(torch.unbind(v, dim=0)) - _processed_outputs[k] = v - - if num_images > 0: + + if images: tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) processed_outputs["input_ids"] = torch.stack([ replace_multimodal_token( token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - IMAGE_TOKEN), + target_token=image_token_id, repeats=_processed_outputs[ "vision_query_lengths_images"], ) for _input_ids in processed_outputs["input_ids"] ], dim=0) - if num_videos > 0: + if videos: + _num_per_videos = [ + get_num_combined_frames(len(video)) for video in videos + ] + _processed_outputs["pixel_values_videos"] = [ + _processed_outputs["pixel_values_videos"] + [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] + for _i in range(len(videos)) + ] + _processed_outputs["vision_query_lengths_videos"] = [ + _processed_outputs["vision_query_lengths_videos"] + [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] + for _i in range(len(videos)) + ] + tokenizer = self.info.get_tokenizer() + video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) processed_outputs["input_ids"] = torch.stack([ replace_multimodal_token( token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - VIDEO_TOKEN), - repeats=_processed_outputs[ - "vision_query_lengths_videos"], + target_token=video_token_id, + repeats=[ + sum(lens) for lens in + _processed_outputs["vision_query_lengths_videos"] + ], ) for _input_ids in processed_outputs["input_ids"] ], dim=0) - _ratios = [ - len(_pixel_values) for _pixel_values in - _processed_outputs["pixel_values_videos"] - ] - _num_per_videos = [ - int(_e / sum(_ratios) * - len(_processed_outputs["vision_query_lengths_videos"])) - for _e in _ratios - ] - _processed_outputs["vision_query_lengths_videos"] = [ - _processed_outputs["vision_query_lengths_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(0, num_videos) - ] - processed_outputs.update(_processed_outputs) return processed_outputs @@ -295,7 +288,7 @@ class HCXVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() placeholder = { @@ -306,21 +299,22 @@ class HCXVisionMultiModalProcessor( def get_replacement_hyperclovax( item_idx: int, modality: str, - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): - num_tokens = None + out_item = out_mm_kwargs[modality][item_idx] + if modality == "image": + lens = out_item["vision_query_lengths_images"].data num_tokens = self.info.get_num_image_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_images"][item_idx], ) - if modality == "video": + vision_query_length=lens) + elif modality == "video": + lens = out_item["vision_query_lengths_videos"].data num_tokens = self.info.get_num_video_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_videos"][item_idx], ) - assert isinstance(num_tokens, int) - return [ - placeholder[modality], - ] * num_tokens + vision_query_length=lens) + else: + raise NotImplementedError(modality) + + return [placeholder[modality]] * num_tokens return [ PromptReplacement( @@ -374,7 +368,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( @@ -936,8 +930,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): target_group_size = 0 elif video_group_size < target_group_size: - raise RuntimeError(f"video_group_size < target_group_size!! \ - [{video_group_size} < {target_group_size}]") + raise RuntimeError( + f"{video_group_size=} < {target_group_size=}") assert len(target_features ) == 0, f"target_features is not empty!! {target_features}" @@ -1121,9 +1115,8 @@ def reshape_and_unpad_image_features( base_image_feature = image_feature[0] image_feature = image_feature[1:] - assert (height * width == base_image_feature.shape[0] - ), f"height: {height}, width: {width}, \ - base_image_feature.shape[0]: {base_image_feature.shape[0]}" + assert height * width == base_image_feature.shape[0], ( + f"{height=} * {width=} != {base_image_feature.shape[0]=}") num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size, possible_resolutions, grid_size) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 9e27200fb1c8982da8c1ab06e758684458a8a847..0ca2e9e4bb688d0d65895dfee78191a4f9a2b6c2 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import ( Idefics2Config, Idefics2VisionConfig) from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): @@ -106,7 +108,7 @@ class Idefics2VisionEmbeddings(nn.Module): bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) - embeddings = embeddings + self.position_embedding(position_ids) + embeddings += self.position_embedding(position_ids) return embeddings @@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config @@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module): f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + assert self.num_heads % tp_size == 0 + self.num_heads_per_partition = self.num_heads // tp_size + + if use_data_parallel: + self.q_size = self.num_heads * self.head_dim + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3 * self.q_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = ReplicatedLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear( + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2( config.intermediate_size, config.hidden_size, bias=True, @@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module): config: Idefics2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Idefics2VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics2VisionMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -229,11 +262,11 @@ class Idefics2EncoderLayer(nn.Module): residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual return hidden_states @@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module): *, num_hidden_layers_override: Optional[int] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module): self.layers = nn.ModuleList([ Idefics2EncoderLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module): num_hidden_layers_override: Optional[int] = None, require_post_norm: bool = True, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() embed_dim = config.hidden_size self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = Idefics2VisionEmbeddings(config) self.encoder = Idefics2Encoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, - prefix=f"{prefix}.encoder") + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) num_hidden_layers = config.num_hidden_layers if len(self.encoder.layers) > config.num_hidden_layers: @@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module): patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - encoder_outputs = self.encoder(hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) + if self.use_data_parallel: + weights = self._consolidate_qkv_weights(weights) + for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): @@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module): continue for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) param = params_dict[name] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3c01789b900667a89c39c1fe4cf507473270a085..63307470d959baf0ea7ec9a48c480ac2784ec40c 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -34,7 +34,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageProcessorItems, ImageSize # yapf conflicts with isort for this block # yapf: disable @@ -374,7 +374,7 @@ class Idefics3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token, _, _ = self.info._get_image_token(hf_processor) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d696ed08d7c4b77a3811ef4be04d817634202223..f27763d354446d3e5859128feb3453b878cac0ed 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_multimodal_raw_input_only: ClassVar[bool] = False + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + """ + + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. + """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -141,38 +153,14 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) -@runtime_checkable -class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): - """The interface required for all multi-modal models.""" - - supports_multimodal_raw_input: ClassVar[Literal[True]] = True - """ - A flag that indicates this model supports multi-modal inputs and processes - them in their raw form and not embeddings. - - Note: - There is no need to redefine this flag if this class is in the - MRO of your model class. - """ - - -@overload -def supports_multimodal_raw_input( - model: object) -> TypeIs[SupportsMultiModalWithRawInput]: - ... +def supports_multimodal_raw_input_only( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_multimodal_raw_input_only", False) -@overload -def supports_multimodal_raw_input( - model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: - ... - - -def supports_multimodal_raw_input( - model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], - TypeIs[SupportsMultiModalWithRawInput]]: - return getattr(model, "supports_multimodal_raw_input", False) +def supports_multimodal_encoder_tp_data( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) @runtime_checkable @@ -645,20 +633,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -def default_pooling_type(pooling_type: str) -> object: - """Set default_pooling_type decorator. """ - - def func(model: object): - model.default_pooling_type = pooling_type - return model - - return func - - -def get_default_pooling_type(model: Union[type[object], object]) -> str: - return getattr(model, "default_pooling_type", "LAST") - - class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 697fa020deb46d484d0bf6fd7dc4b7fa422c5a6a..19a3ef1a3b800ba34b9085908afb64be0552cc34 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. + """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -165,3 +176,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d0c4bf5450d6d9a064890afeb69b9b38dfeddfbb..320e8d9d480c3ffe83d02f025a717cfa01f78e1e 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Any, Optional, Union import torch @@ -31,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -296,7 +298,7 @@ class InternLM2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 4bbb49da0e96f2b15a519ff5fa133ea86d5e2cdd..d41ac2b70bc697d37f513c40ad7fc4032c32bce0 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from itertools import islice from typing import Optional, Union import torch @@ -123,7 +124,7 @@ class InternLM2VEModel(InternLM2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index d952ced2fa69fe1b3099115caa4b5f8e9a28a706..c739e74b058fa59bd57438e002f049ad67bed5de 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -24,7 +24,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -399,7 +399,7 @@ class InternS1MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) img_context_token = hf_processor.image_token @@ -407,15 +407,16 @@ class InternS1MultiModalProcessor( end_image_token = hf_processor.end_image_token video_token = hf_processor.video_token - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: video_num_patches = [] - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() else: diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 8e766dd4c476892e5c3c00ec8114e60a281a52b6..b09ed7bbe72a39acdd7b638a6ccfe29b586b53ac 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -28,7 +28,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -797,18 +797,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -854,9 +855,13 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): def get_video_token(self) -> Optional[str]: text_model_type = self.get_hf_config().get_text_config().model_type - if text_model_type == "qwen2": - return "<|video_pad|>" - return None + video_token_map = { + "qwen2": "<|video_pad|>", + "qwen3": "<|video_pad|>", + "qwen3_moe": "<|video_pad|>", + "gpt_oss": "<|reserved_200000|>", + } + return video_token_map.get(text_model_type) def get_num_frames_with_most_features( self, @@ -966,15 +971,19 @@ class InternVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - prompt_repl: list[PromptUpdate] = super()._get_prompt_updates( - mm_items, hf_processor_mm_kwargs, out_mm_kwargs) + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: @@ -992,12 +1001,15 @@ class InternVLMultiModalProcessor( video_context_token=hf_processor.video_token) if self.info.supports_video: - prompt_repl.append( + prompt_repl = [ + *prompt_repl, PromptReplacement( modality="video", target="<video>", replacement=get_video_replacement_internvl, - )) + ) + ] + return prompt_repl diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bed4a5dff2efae1a10d25b10c709e2150fa8eccc..91a06dd502474398bb96054e4505400d15658860 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -23,6 +23,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -276,7 +277,7 @@ class JAISModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0b32d6f256590ea1b52920b2cd130377dadcfa38..aebd2cbe2e99997ce71f7578e33278a8a2eb2a5f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jamba model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -10,6 +11,7 @@ from transformers import JambaConfig from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -154,10 +156,10 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, mamba_cache_params) + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -278,6 +280,7 @@ ALL_DECODER_LAYER_TYPES = { } +@support_torch_compile class JambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -348,7 +351,7 @@ class JambaModel(nn.Module): kv_cache_index = 0 mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache_index += 1 diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index db9ed5910d78bfbec8e0a5fc24898f8307aaa9f7..af36197d94bf3b5b66bf80a428ec4923c25e41e0 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -33,7 +33,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, is_pp_missing_parameter, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -322,7 +323,10 @@ def apply_rotary_pos_emb_flashatt( cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + if not current_platform.is_rocm(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + else: + from flash_attn.layers.rotary import apply_rotary_emb q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) @@ -1192,7 +1196,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1208,7 +1212,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): merge_length = image_processor.merge_size**2 def get_replacement_keye(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 1c7ddd7df7f8235d8c80d5c98a7a88b733c392e0..a08a9a62a57c5dab041283e9883163d800a3cf1a 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -54,8 +54,7 @@ from transformers import BatchFeature from transformers.activations import GELUActivation from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -63,13 +62,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model -from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.interfaces import (SupportsMultiModal, + SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .utils import is_pp_missing_parameter, maybe_prefix +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix # For dummy input only @@ -239,7 +239,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_token_id = self.info.image_token_id @@ -270,7 +270,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): @MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, info=KimiVLProcessingInfo, dummy_inputs=KimiVLDummyInputsBuilder) -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -304,17 +305,21 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): prefix=maybe_prefix(prefix, "language_model"), ) self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.media_placeholder: int = self.config.media_placeholder_token_id - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_world_size = get_tensor_model_parallel_world_size() # ref: qwen2_vl.py def _validate_and_reshape_mm_tensor(self, mm_input: object, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100644 index 0000000000000000000000000000000000000000..927f78c4e4b45440f02515f6f1af07ebb662d055 --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import Lfm2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Lfm2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2Attention(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Lfm2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + conv_metadata=None, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Lfm2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + """ Calculate shapes for LFM2's convolutional cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.conv_dim, + conv_kernel=hf_config.conv_L_cache, + use_v1=use_v1, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "Lfm2 currently does not support prefix caching" + assert envs.VLLM_USE_V1, ( + "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = Lfm2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9a487eda0a6e8f9b770b00321b233296ecbe957a..84748abd80b1740c280100eed13e160c16371bd0 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -34,6 +35,7 @@ import os import re from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -179,7 +181,10 @@ class LlamaAttention(nn.Module): if is_sliding: sliding_window = config.sliding_window - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -362,7 +367,7 @@ class LlamaModel(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -403,7 +408,7 @@ class LlamaModel(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) @@ -620,10 +625,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 308cb3e85e27bdee7ab27037629e41b9630ca63d..ba08e6f81f7fe7397425a26a210ebee5539c42b1 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -195,7 +195,9 @@ class Llama4Attention(nn.Module): is_neox_style=is_neox_style, ) if not self.nope else None - attn_cls = Attention if self.nope else ChunkedLocalAttention + use_chunked_local_attn = not self.nope and config.attention_chunk_size + attn_cls = (ChunkedLocalAttention + if use_chunked_local_attn else Attention) self.attn = attn_cls( self.num_heads, self.head_dim, @@ -206,7 +208,7 @@ class Llama4Attention(nn.Module): prefix=f"{prefix}.attn", **({ "attention_chunk_size": config.attention_chunk_size - } if not self.nope else {})) + } if use_chunked_local_attn else {})) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4927d6b62c6d8c06b9b4e888b4c73ccb6d845d44..8a847a6180f3a7694e87e5cff0968238d33e8e0f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -250,7 +250,7 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -343,7 +343,7 @@ class PixtralHFMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( @@ -795,7 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -806,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): image_height=-1, ) - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + result = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() @@ -829,26 +832,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) - ]) + ], mm_item_counts) prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, - mm_item_counts, ) - unbound_orig_repls = self._get_prompt_updates( + orig_repls = self._get_mm_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_updates(unbound_orig_repls) - - mm_placeholders = self._find_mm_placeholders( - orig_repls, - prompt_ids, - mm_item_counts, - ) + mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index abc519edadccae41660fcfdf20d9ebc3beb1e5b2..cf9852de633f3f8cabe12228d3b7d1b7ae6c1b2e 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -16,7 +16,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -185,7 +185,7 @@ class LlavaNextVideoMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index ecd24af030a14622c7119a25877da96e6701becb..e4ac0cd919101982b714a41984b54210a7208fbe 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Final, Literal, Optional, Protocol, TypedDict, Union +from typing import Annotated, Final, Literal, Optional, Protocol, Union import torch import torch.nn as nn @@ -11,18 +11,18 @@ from transformers import (BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor) from transformers.models.llava_onevision.modeling_llava_onevision import ( get_anyres_image_grid_shape, unpad_image) -from typing_extensions import NotRequired from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, _MAX_FRAMES_PER_VIDEO = 16 -class LlavaOnevisionVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionVideoPixelInputs(TensorSchema): """ - Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)` - - Note that `num_videos` may be different for each batch, and 'num_frames' - may be different for each video, in which case the data is passed as a - list instead of a batched tensor. + Dimensions: + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) + - h: Height + - w: Width + + Note that `num_videos` may be different for each batch, and 'num_frames' + may be different for each video, in which case the data is passed as a + list instead of a batched tensor. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] -class LlavaOnevisionImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. +class LlavaOnevisionImagePixelInputs(TensorSchema): """ - - image_sizes: NotRequired[torch.Tensor] + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ - Shape: `(batch_size * num_images, 2)` + type: Literal["pixel_values"] = "pixel_values" - This should be in `(height, width)` format. - """ + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), + ] + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] -class LlavaOnevisionImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class LlavaOnevisionImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + + data: Annotated[ + torch.Tensor, + TensorShape("bn", "ifs", "hs"), + ] LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, @@ -372,7 +390,7 @@ class LlavaOnevisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_repls = super()._get_prompt_updates( mm_items=mm_items, @@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_image_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=self._validate_image_pixel_values( - flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True)), - ) + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): @@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _validate_video_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[2:]) - - if actual_dims != expected_dims: - expected_expr = ("num_frames", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values in each video frame " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: @@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=flatten_bn(pixel_values_videos), - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f4aaf0c6f467cfca181a681658285201fd24541c..f02499a4f96b58e7faf4c14b284d8e4e5075c96a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -9,6 +9,7 @@ from torch import nn from transformers import MambaConfig from vllm import envs +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params) + return output, residual +@support_torch_compile class MambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 3432cf29feac6a50c9cbb425c1fcdb6515f71b02..81b9a125380aa97512af16caba2409b4c9a6ce95 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -164,9 +164,7 @@ class Mamba2Model(nn.Module): # v1 get mamba2_metadata from forward_context mamba2_metadata = None - for i in range(len(self.layers)): - layer = self.layers[i] - + for i, layer in enumerate(self.layers): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 5b497dd9d89f5de1fd578a42a5770b259c6f25a2..ea5292d0df202518e6d89dd9f027fbd0f460f8ad 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -26,6 +26,7 @@ # limitations under the License. """Inference-only MiMo model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -74,7 +75,7 @@ class MiMoModel(Qwen2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index d398a5d12bbcdbbfa4610ce8e9799d52bc64b370..5632f8c8cc4fbed6ae0b1194dfc4c857eab10d80 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -25,6 +25,7 @@ """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -414,7 +415,7 @@ class MiniCPMModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index e1746695bd5dbc58c2cbd5e49acaf1f61beef1da..225668d87facbbe8b155cf9125c166019184966e 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -24,7 +24,7 @@ # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch from torch import nn @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, @@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, MultiModalDataParser) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, MiniCPMVDummyInputsBuilder, @@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, CPU_DEVICE = torch.device("cpu") -class MiniCPMOAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - audio_features: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioFeatureInputs(TensorSchema): """ - Shape: `(batch_size * num_audios * num_slices, num_channels, length)` - Slice here means chunk. Audio that is too long will be split into slices, - which is the same as image. - Padding is used therefore `audio_features` is `torch.Tensor`. + Dimensions: + - bns: Batch size * number of audios * number of slices + - bn: Batch size * number of audios + - c: Number of channels + - l: Length + - s: Number of slices """ + type: Literal["audio_features"] = "audio_features" - audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] + audio_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bns", "c", "l", dynamic_dims={"l"}), + ] + """ + Slice here means chunk. Audio that is too long will be split into slices, + which is the same as image. Padding is used therefore `audio_features` is + `torch.Tensor`. """ - Shape: `(batch_size * num_audios, num_slices)` + audio_feature_lens: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s"), + ] + """ This should be feature length of each audio slice, which equals to `audio_features.shape[-1]` """ -class MiniCPMOAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_audios, num_slices, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - instead of a batched tensor. + Dimensions: + - bn: Batch size * number of audios + - s: Number of slices + - h: Hidden size (must match language model backbone) + Length of each slice may vary, so pass it as a list. """ + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s", "h", dynamic_dims={"s"}), + ] MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, @@ -316,7 +334,7 @@ class MiniCPMOMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: base_updates = super()._get_prompt_updates( mm_items=mm_items, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 47ce771d8c90196c19c4d7d3edd24be55b7ef124..04176c5589ed6f6d886ee8133ed1790ef69af531 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -27,12 +27,14 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import chain from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar @@ -47,10 +49,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -58,7 +61,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + PromptUpdate, PromptUpdateDetails, + ResolvedPromptUpdate, _seq2text) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -217,6 +221,187 @@ class Resampler2_5(BaseResampler): return x +class Resampler4_5(Resampler2_5): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix) + + trunc_normal_(self.query, std=.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, + pos: np.ndarray): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu") -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size)).float().to(device) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu"): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q * D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, + device=device)) + else: + pos_embed_temporal.append(self.temporal_pos_embed[ + temporal_ids_flatten[i]].to(dtype)) # D + + pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, + padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1)) + + start = end + + k = torch.nn.utils.rnn.pad_sequence(merge_k, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence(merge_v, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, + padding_value=True).squeeze(-1) + + out = self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -353,9 +538,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = {"image": None} - if self.get_model_version() == (2, - 6) or self.get_model_version() == (4, - 0): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -636,8 +819,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == ( - 2, 6) or self.info.get_model_version() == (4, 0): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -694,7 +876,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: placeholders = [("image", self.info.image_pattern), ("video", self.info.video_pattern)] @@ -744,6 +926,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", + 1, + ), + "<unk>", + )) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -758,6 +977,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): instantiated. """ + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -771,6 +992,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config quant_config = vllm_config.quant_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot @@ -1018,6 +1240,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): class MiniCPMV2_0(MiniCPMVBaseModel): + supports_encoder_tp_data = False + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (2, 0) @@ -1132,9 +1356,12 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1222,9 +1449,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1325,9 +1555,12 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): prefix: str = "", ) -> nn.Module: quant_config = self._maybe_ignore_quant_config(quant_config) - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1395,11 +1628,124 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): return loader.load_weights(weights) +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.0 remains consistent with the one in 2.5/2.6. + resampler = Resampler4_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get('temporal_ids', None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( + temporal_ids) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 82e96844cd5f666d846d770d96f7435f21572771..ef1fe86c5b5c035c53fee25e37e9ffb4cc9d314b 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,52 +1,48 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy -import math from collections.abc import Iterable -from typing import Optional, Union +from itertools import islice +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + pass import regex as re import torch import torch.distributed -import torch.nn.functional as F -from einops import rearrange from torch import nn from transformers import MiniMaxConfig from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.lightning_attn import ( - lightning_attention, linear_decode_forward_triton) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.linear_attn import ( + MiniMaxText01LinearAttention) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from .interfaces import HasInnerState, IsHybrid from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams @@ -80,121 +76,6 @@ def weight_loader_with_alias(alias: str): return wrapper -class MiniMaxText01RMSNormTP(CustomOp): - name = "MiniMaxText01RMSNormTP" - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.tp_world = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.weight = nn.Parameter(torch.ones(int(hidden_size / - self.tp_world))) - - self.weight.weight_loader = self.weight_loader - self.variance_epsilon = eps - return - - @staticmethod - def weight_loader( - param: nn.Parameter, - loaded_weight: torch.Tensor, - ) -> None: - tp_world = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - shard_size = loaded_weight.shape[0] // tp_world - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - param.data.copy_(loaded_weight[shard]) - return - - def _forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - orig_dtype = x.dtype - x = x.to(torch.float32) - variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) - if self.tp_world > 1: - variance = tensor_model_parallel_all_reduce( - variance) / self.tp_world - x = x * torch.rsqrt(variance + self.variance_epsilon) - - weight = self.weight - if x.size(-1) != self.weight.size(0): - if self.weight.size(0) < x.size(-1): - repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) - full_weight = self.weight.repeat(repeat_count) - weight = full_weight[:x.size(-1)] - else: - weight = self.weight[:x.size(-1)] - - x = x.to(orig_dtype) * weight - return x - - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - assert residual is None, "RMSNorm does not support residual connection." - return self._forward(x) - - -class MiniMaxText01RotaryEmbedding(CustomOp): - name = "MiniMaxText01RotaryEmbedding" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool, - cache_dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position - self.base = base - self.is_neox_style = is_neox_style - self.cache_dtype = cache_dtype - cache = self._compute_cos_sin_cache().to(cache_dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) - query_cast = query.to(self.cache_dtype) - key_cast = key.to(self.cache_dtype) - ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, - self.cos_sin_cache, self.is_neox_style) - query = query_cast.to(query.dtype) - key = key_cast.to(key.dtype) - return query, key - - class MiniMaxText01MLP(nn.Module): def __init__( @@ -301,284 +182,6 @@ class MiniMaxText01MoE(nn.Module): return final_hidden -class MiniMaxText01LinearKernel: - - @staticmethod - def jit_linear_forward_prefix(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_caches: torch.Tensor, - slope_rate: torch.Tensor, - block_size: int, - layer_idx: int = None, - **kwargs) -> torch.Tensor: - - slope_rate = slope_rate.to(torch.float32) - should_pad_dim = q.dim() == 3 - if should_pad_dim: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - b, h, n, d = q.shape - e = d - kv_history = kv_caches.reshape(1, h, d, e).contiguous() - output, kv_history = lightning_attention(q, - k, - v, - slope_rate, - block_size=block_size, - kv_history=kv_history) - kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) - assert output.shape[0] == 1, "batch size must be 1" - return rearrange(output.squeeze(0), "h n d -> n (h d)") - - -class MiniMaxText01LinearAttention(nn.Module, MambaBase): - - @property - def mamba_type(self) -> str: - return "linear_attention" - - def get_state_dtype(self) -> tuple[torch.dtype]: - return MambaStateDtypeCalculator.linear_attention_state_dtype( - self.model_config.dtype, - self.cache_config.mamba_cache_dtype, - ) - - def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return MambaStateShapeCalculator.linear_attention_state_shape( - num_heads=self.num_heads, - tp_size=self.tp_size, - head_dim=self.head_dim) - - def __init__( - self, - hidden_size: int, - hidden_inner_size: int, - num_heads: int, - head_dim: int, - max_position: int, - block_size: int, - num_hidden_layer: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_idx: int = 0, - linear_layer_idx: int = 0, - prefix: str = "linear_attn", - ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.BLOCK = block_size - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = head_dim - self.total_num_heads = num_heads - self.hidden_inner_size = hidden_inner_size - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size - self.qkv_size = self.num_heads * self.head_dim - self.tp_hidden = self.head_dim * self.tp_heads - self.model_config = model_config - self.cache_config = cache_config - self.prefix = prefix - - self.qkv_proj = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size * 3, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.output_gate = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.output_gate", - ) - self.out_proj = RowParallelLinear( - self.hidden_inner_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.norm = MiniMaxText01RMSNormTP( - self.hidden_inner_size, - eps=1e-5, - ) - - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.num_heads) - if num_hidden_layer <= 1: - self.slope_rate = slope_rate * (1 + 1e-5) - else: - self.slope_rate = slope_rate * (1 - layer_idx / - (num_hidden_layer - 1) + 1e-5) - self.tp_slope = self.slope_rate[self.tp_rank * - self.tp_heads:(self.tp_rank + 1) * - self.tp_heads].contiguous() - - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - @staticmethod - def weight_direct_load(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - return - - @staticmethod - def _build_slope_tensor(n_attention_heads: int): - - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.tensor(get_slopes(n_attention_heads), - dtype=torch.float32).reshape( - n_attention_heads, 1, 1) - return slopes - - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - hidden = [] - for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx >= len(attn_metadata.query_start_loc): - break - if _prefill_idx >= len(state_indices_tensor): - break - # prefills are packed at end of batch in V1 - offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 - _start = attn_metadata.query_start_loc[offset + _prefill_idx] - _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] - slot_id = state_indices_tensor[offset + _prefill_idx] - qs = q[_start:_end].transpose(0, 1).contiguous() - ks = k[_start:_end].transpose(0, 1).contiguous() - vs = v[_start:_end].transpose(0, 1).contiguous() - slice_layer_cache = kv_cache[slot_id, ...] - - out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( - qs, - ks, - vs, - slice_layer_cache, - self.tp_slope, - self.BLOCK, - layer_idx=self.layer_idx) - hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden_decode = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - if envs.VLLM_USE_V1: - hidden.insert(0, hidden_decode) - else: - hidden.append(hidden_decode) - - if not hidden: - return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) - - hidden = torch.concat(hidden, dim=0).contiguous() - return hidden - - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - if not envs.VLLM_USE_V1: - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - num_prefills = getattr(attn_metadata, "num_prefills", 0) - slot_id = state_indices_tensor[num_prefills:] - else: - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decodes] - hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, - slot_id, 32) - return hidden - - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - qkv32 = qkv.to(torch.float32) - qkvact = torch.nn.functional.silu(qkv32) - qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) - q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) - kv_cache = self.kv_cache[forward_context.virtual_engine][0] - state_indices_tensor = attn_metadata.state_indices_tensor - - num_prefills = getattr(attn_metadata, "num_prefills", 0) - if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, - "num_decode_tokens", 0) - for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx] - q_end = attn_metadata.query_start_loc[num_decode_tokens - + prefill_idx + - 1] - query_len = q_end - q_start - context_len = attn_metadata.seq_lens[ - num_decode_tokens + prefill_idx] - query_len - if context_len == 0: - block_to_clear = state_indices_tensor[ - num_decode_tokens + prefill_idx] - kv_cache[block_to_clear, ...] = 0 - else: - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor - - decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if attn_metadata is None: - hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), - device=q.device, - dtype=q.dtype) - else: - if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) - hidden = F.sigmoid(gate) * hidden - hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden - - class MiniMaxText01Attention(nn.Module): def __init__( @@ -644,23 +247,23 @@ class MiniMaxText01Attention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.attn", ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=int(rope_theta), + is_neox_style=True, + dtype=torch.float32, + ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, **kwargs) -> None: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb( - positions, q, k) - else: - q, k = attn_metadata.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): @@ -808,16 +411,15 @@ class MiniMaxText01DecoderLayer(nn.Module): is_warmup: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha @@ -831,8 +433,8 @@ class MiniMaxText01DecoderLayer(nn.Module): if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) @@ -870,18 +472,16 @@ class MiniMaxText01DecoderLayer(nn.Module): return +@support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: MiniMaxConfig, - model_config: Optional[ModelConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -968,24 +568,6 @@ class MiniMaxText01Model(nn.Module): self.minimax_cache = MinimaxCacheManager( dtype=torch.float32, cache_shape=self.cache_shape) - rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) - self.rotary_emb = MiniMaxText01RotaryEmbedding( - head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - is_neox_style=True, - cache_dtype=torch.float32, - ) - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps @@ -1035,12 +617,11 @@ class MiniMaxText01Model(nn.Module): attn_metadata = forward_context.attn_metadata if not envs.VLLM_USE_V1 and attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - if not envs.VLLM_USE_V1: + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] ( minimax_cache_tensors, state_indices_tensor, @@ -1067,18 +648,7 @@ class MiniMaxText01Model(nn.Module): minimax_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - if attn_metadata is not None: - # TODO (tdoublep): this whole thing with the rotary_emb is - # weird. we shouldn't be passing it via attn_metadata imo. - if envs.VLLM_USE_V1: - if isinstance(layer.self_attn, MiniMaxText01Attention): - attn_metadata[layer.prefix + - ".attn"].rotary_emb = self.rotary_emb - else: - attn_metadata.rotary_emb = self.rotary_emb - + for layer in islice(self.layers, self.start_layer, self.end_layer): _caches = None if not envs.VLLM_USE_V1 and isinstance( layer.self_attn, MiniMaxText01LinearAttention): @@ -1112,7 +682,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1125,13 +694,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 8107c6e8a04a192c3813898843ea9bbb0bccc2ba..cc7db849a28bff4901a3d4271953395f5797f2c2 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn @@ -17,6 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -29,24 +32,36 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -class MiniMaxVL01ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class MiniMaxVL01ImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - - Note that `height` or `width` may be different per batch and image, + Dimensions: + - bn: Batch size * number of images + - np: Number of patches + 1 + - c: Number of channels (3) + - h: Height + - w: Width + + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + # This should be in `(height, width)` format. -class MiniMaxVL01ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, @@ -141,6 +156,7 @@ class MiniMaxVL01MultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return { "pixel_values": MultiModalFieldConfig.batched("image"), + "image_sizes": MultiModalFieldConfig.batched("image"), "image_embeds": MultiModalFieldConfig.batched("image"), } @@ -239,7 +255,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = tuple(vision_tower(p) for p in pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -252,6 +268,56 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, json_map_leaves(select_features, image_features), ) + # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 + def pack_image_features(self, image_features: list[torch.Tensor], + image_sizes: torch.Tensor): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with " + "the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, + self.image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + return new_image_features + def _process_image_pixels( self, inputs: MiniMaxVL01ImagePixelInputs, @@ -259,7 +325,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, assert self.vision_tower is not None pixel_values = inputs["pixel_values"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) def _process_image_input( @@ -281,38 +346,31 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) - return image_embeds - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data + image_sizes = image_input.get("image_sizes") + return self.pack_image_features(image_embeds, image_sizes) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None: + if pixel_values is not None and image_sizes is not None: if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 9e29a96c6e44a8d546d206b243d61ac9b6e8f6e1..08948960b275c0386f3e84b4bc3dbf8d0894057d 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union) import torch @@ -22,16 +22,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -42,16 +43,24 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class Mistral3ImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class Mistral3ImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + + # Note that `height` or `width` may be different per batch and image, + # in which case the data is passed as a list instead of a batched tensor. + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] + class Mistral3PatchMerger(nn.Module): """ @@ -265,7 +274,7 @@ class Mistral3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -313,7 +322,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( @@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 43c8b18f2ab0133ca8e950b8c1f0efec4817b131..665e0be501ccdb3950cd1097a750e90f2339e89c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -28,6 +28,7 @@ import re from typing import Iterable, Optional, Union from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -319,7 +320,7 @@ class MixtralModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index c8ad358c622d256af38faf39181e65207a178cc4..692267b4d7271c353218e6227bd93e404a1d9918 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only Mixtral model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import numpy as np @@ -346,7 +347,7 @@ class MixtralModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 30ae3f26c8e45bad441e33b5d2b62bd5d2825e59..f441287a4d08965e61515ef8860771200cebaf74 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -17,7 +17,7 @@ """PyTorch Mllama model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import numpy as np import torch @@ -56,13 +56,15 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalKwargs) + MultiModalFieldConfig, + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only @@ -72,15 +74,30 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = init_logger(__name__) -class MllamaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: """ - """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" - aspect_ratio_ids: torch.Tensor - """Shape: `(batch_size, max_num_image)`""" - aspect_ratio_mask: torch.Tensor - """Shape: `(batch_size, max_num_image, max_num_tiles)`""" +class MllamaImagePixelInputs(TensorSchema): + """ + Dimensions: + - batch_size: Batch size + - max_num_image: Max number of images + - max_num_chunk: Max number of chunks + - max_num_tiles: Max number of tiles per image + - num_channel: Number of channels + - height: Height + - width: Width + """ + + type: Literal["pixel_values"] = "pixel_values" + + data: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_chunk", + "num_channel", "height", "width")] + + aspect_ratio_ids: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image")] + + aspect_ratio_mask: Annotated[ + torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_tiles")] # TODO: support LlamaImageEmbeddingInputs @@ -167,10 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalEncDecInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + mm_inputs = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches @@ -217,7 +237,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. - num_tiles = mm_inputs["mm_kwargs"]["num_tiles"] + num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"] decode_tiles = num_tiles[num_encode_images:num_images].sum().item() num_tokens = decode_tiles * token_per_chunk mm_inputs["encoder_prompt_token_ids"] = [image_token_id @@ -302,7 +322,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: token_per_chunk = self.info.get_token_per_chunk_from_config() image_token_id = self.info.get_hf_config().image_token_index @@ -1351,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, output_tensor[i, :t.size(0)] = t return output_tensor - def _parse_and_validate_image_input(self, **kwargs: object): + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[MllamaImagePixelInputs]: # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: # - list[torch.Tensor]: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 72b91732e58cb06fd08a534b253a64cc8a3d7365..e56303b6dfeba99d2f4a13cbc181b3b9fd023a63 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -19,7 +19,7 @@ import math from collections.abc import Iterable, Mapping from itertools import tee -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM @@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -class Llama4ImagePatchInputs(TypedDict): - type: Literal["pixel_values"] - flat_data: torch.Tensor +class Llama4ImagePatchInputs(TensorSchema): """ - Shape: - `(batch_size * num_chunks, num_channels, image size, image size)` + Dimensions: + - batch_size: Batch size + - total_num_chunks: Batch size * number of chunks + - num_channels: Number of channels + - image_size: Size of each image """ - patches_per_image: torch.Tensor + + type: Literal["pixel_values"] = "pixel_values" + + flat_data: Annotated[torch.Tensor, + TensorShape("total_num_chunks", "num_channels", + "image_size", "image_size")] + + patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")] """ The number of total patches for each image in the batch. - + This is used to split the embeddings which has the first two dimensions flattened just like `flat_data`. """ - aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] + aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] """ A list of aspect ratios corresponding to the number of tiles in each dimension that each image in the batch corresponds to. - - Shape: - `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` + Each aspect ratio is a pair (ratio_h, ratio_w). """ @@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] for (r_h, r_w) in aspect_ratios ] - processed_outputs["aspect_ratios"] = aspect_ratios + processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios) processed_outputs["patches_per_image"] = torch.tensor( patches_per_image) @@ -646,13 +653,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: - assert ( - mm_items.get_count("image", strict=False) == 0 - or "aspect_ratios" in out_mm_kwargs - ), "Transformers expect to include aspect_ratios in out_mm_kwargs" - config = self.info.get_hf_config() vision_config = config.vision_config @@ -662,7 +664,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] img_patch_token = hf_processor.img_patch_token def get_replacement(item_idx: int): - aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + aspect_ratio = out_item["aspect_ratios"].data repl = hf_processor._prompt_split_image( aspect_ratio=aspect_ratio, @@ -720,6 +723,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -732,8 +737,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config @@ -772,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # TODO: confirm handling for variable lengths flat_pixel_values = flatten_bn(pixel_values, concat=True) patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) - - aspect_ratios = kwargs.pop("aspect_ratios", None) - if not isinstance(aspect_ratios, (torch.Tensor, list)): - raise ValueError("Incorrect type of aspect_ratios. " - f"Got type: {type(aspect_ratios)}") + aspect_ratios = kwargs.pop("aspect_ratios") + if aspect_ratios.ndim == 3: + aspect_ratios = aspect_ratios.squeeze(1) return Llama4ImagePatchInputs( type="pixel_values", diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index c6e84e2d4e040783961350fe4abf414b15b4b5bc..776287589808a969599084203b91fe4eb61c66bd 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -7,7 +7,7 @@ import torch from torch import nn from transformers import ModernBertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -22,11 +22,12 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type from .utils import WeightsMapper, maybe_prefix @@ -104,12 +105,12 @@ class ModernBertAttention(nn.Module): head_size=self.head_dim, dim=self.head_dim, base=rope_theta) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - prefix=f"{layer_id}.attn", - attn_type=AttentionType.ENCODER_ONLY, - per_layer_sliding_window=sliding_window) + self.attn = EncoderOnlyAttention( + self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + per_layer_sliding_window=sliding_window) self.Wo = RowParallelLinear(config.hidden_size, config.hidden_size, bias=config.attention_bias) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 78dc0dca957f037cfa569dff2af2d8a0bb88a728..b2fc7be1af224f9db193619809b8d39821e0e362 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -5,7 +5,8 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import Optional, TypedDict, Union +from itertools import islice +from typing import Annotated, Optional, Union import numpy as np import torch @@ -42,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -51,6 +52,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -70,23 +72,25 @@ IM_END_TOKEN = "<im_end>" POOLING_SIZE = 2 -class MolmoImageInputs(TypedDict): - images: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`""" - - image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] - """Shape: `(batch_size * num_images, num_crops, num_patch)`""" - - feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] +class MolmoImageInputs(TensorSchema): """ - A boolean mask indicating which image features correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_crops, num_patch)` + Dimensions: + - bn: Batch size * number of images + - nc: Number of crops + - np: Number of patches + - pd: Patch dimension """ + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np", "pd")] + + image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], + TensorShape("bn", "nc", "np")] + + feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np")] + # A boolean mask indicating which image features correspond to patch tokens. - num_crops: torch.Tensor - """Shape: `(batch_size * num_images)`""" + num_crops: Annotated[torch.Tensor, TensorShape("bn")] @dataclass @@ -839,7 +843,7 @@ class MolmoModel(nn.Module, SupportsQuant): residual = intermediate_tensors["residual"] # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -1282,7 +1286,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -1410,28 +1414,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, **kwargs: object, ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) - if images is None: - return None - - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - image_masks = kwargs.pop("image_masks", None) - if not (image_masks is None or isinstance(image_masks, - (torch.Tensor, list))): - raise ValueError("Incorrect type of image_masks. " - f"Got type: {type(image_masks)}") - feat_is_patch = kwargs.pop("feat_is_patch", None) - if not isinstance(feat_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of feat_is_patch. " - f"Got type: {type(feat_is_patch)}") - num_crops = kwargs.pop("num_crops", None) + + if images is None: + return None + if not isinstance(num_crops, (torch.Tensor, list)): raise ValueError("Incorrect type of num_crops. " f"Got type: {type(num_crops)}") + num_crops = flatten_bn(num_crops, concat=True) img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): @@ -1439,8 +1432,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() - num_crops = flatten_bn(num_crops, concat=True) - return MolmoImageInputs( images=images, image_masks=image_masks, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8db52a69924c9cfbe16956d12c0efb33527f2324..48ac91fa6dde037b3ac0962ed80ccca99b1b31fd 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -4,6 +4,7 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -260,7 +261,7 @@ class MPTModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index eabf47b1aede4c5609978c2b1dd9cdd15ad27e27..10adc62d3de38394b57fb0e41f7ab28283616869 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only Nemotron model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -353,7 +354,7 @@ class NemotronModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 07cd5a4c6e24f8c97702acb4208075a7399380f7..8a563288cb4d6d592505e25276fe7897ff5c5607 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -399,8 +399,7 @@ class NemotronHModel(nn.Module): residual = None num_non_mamba_layers = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): layer_mamba_cache_params = None if isinstance(layer, NemotronHMambaDecoderLayer) and mamba_cache_params: diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index a766ed9476a657dba04e9d29f694b0f0730e9959..f8e38dcd80b5ae2f751f2a67b32760100f2c819e 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only deci model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -287,8 +288,7 @@ class DeciModel(nn.Module): residual = intermediate_tensors["residual"] kv_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): if not layer._is_no_op_attention: hidden_states, residual = layer(positions, hidden_states, residual) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 82bcd064624f39ee87c4e2f78159d6ba86545b8d..a9c7d8044e10c2ddb572213d7ef44f1811943fe0 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -458,27 +458,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - #use force_image_size to get image_size - h = w = self.config.force_image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -516,9 +495,12 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return InternVLImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, + resolve_bindings={ + "h": self.config.force_image_size, + "w": self.config.force_image_size + }, ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 4bea1392a6814972f39f2f0d1215a864c70374c8..3bbf4c67604c75d5b1f3fe37e462b32e88b6654b 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -16,7 +16,7 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, @@ -106,18 +106,19 @@ class NVLMMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 1dc4df85c1bc4016225263ac0da531bc6f9e6325..71575989565a8629fbd86a293e91d4d10726d11e 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -47,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -91,6 +92,7 @@ class OlmoAttention(nn.Module): self.total_num_heads, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) # Rotary embeddings. @@ -114,6 +116,7 @@ class OlmoAttention(nn.Module): self.hidden_size, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) def forward( @@ -142,6 +145,7 @@ class OlmoMLP(nn.Module): self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -154,6 +158,7 @@ class OlmoMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) # Activation function. @@ -165,6 +170,7 @@ class OlmoMLP(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) def forward( @@ -197,7 +203,7 @@ class OlmoDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = OlmoMLP(config, quant_config) + self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp") # LayerNorm self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -275,7 +281,7 @@ class OlmoModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -326,10 +332,21 @@ class OlmoModel(nn.Module): return loaded_params -class OlmoForCausalLM(nn.Module, SupportsPP): +class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 499e6d30ed6b0834b3360b1a77a864b11f2c6d7d..bccd1b87043a5cbacd5618456cbb9c677e2938af 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -26,6 +26,7 @@ from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Optional, Union import torch @@ -33,6 +34,7 @@ from torch import nn from transformers import Olmo2Config from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -48,7 +50,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -253,6 +255,7 @@ class Olmo2DecoderLayer(nn.Module): return hidden_states +@support_torch_compile class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -303,7 +306,7 @@ class Olmo2Model(nn.Module): assert isinstance(hidden_states, torch.Tensor) # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -354,10 +357,21 @@ class Olmo2Model(nn.Module): return loaded_params -class Olmo2ForCausalLM(nn.Module, SupportsPP): +class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index a47c3bd41645930fba780c0441b0a2feacc1a88b..9b8525bfadecefa763a94c3ebbc74d70c07b93b8 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -15,6 +15,7 @@ """Inference-only OLMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Any, Optional, Union import torch @@ -314,7 +315,7 @@ class OlmoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 9eaac1e28dcd87eb09a94149c938c690be754132..b92e586f0bf21ec723f62d02c0e88cbc16f23df7 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -20,6 +20,7 @@ # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -269,7 +270,7 @@ class OPTDecoder(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index d121188ba5d4a91ca435274f6a0b74b590e84b15..add751ebf09ccb3bafaf3ba8bb919397022adc99 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -7,6 +7,7 @@ # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -252,7 +253,7 @@ class OrionModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 6b27980e0b0c3d743408e30664a06b57f7090872..04a06e5f9d6004b82b2481fb6564f4013e0ca52e 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -42,7 +42,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) @@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict): `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` """ - inducator_tokens: torch.Tensor + indicator_tokens: torch.Tensor """ Shape: `(batch_size * (num_patches + 1))` @@ -375,11 +375,12 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx): - grid = out_mm_kwargs["grids"][item_idx] + def get_replacement_ovis(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data hf_processor = self.info.get_hf_processor() return hf_processor.construct_image_placeholders(grid) diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4758ef8ea5d7ccdd8c22b035d5341ac3583edf --- /dev/null +++ b/vllm/model_executor/models/ovis2_5.py @@ -0,0 +1,644 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" PyTorch Ovis model.""" +from collections.abc import Iterable, Mapping +from functools import partial +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.ovis import (OvisImagePatchInputs, + VisualEmbedding) +from vllm.model_executor.models.siglip2navit import Siglip2NavitModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP + +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +INDICATOR_IDS = [-301, -302, -303, -304] + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "<unused0>", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", + "qwen3": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, + "qwen3": 151655, +} + + +class OvisVideoPatchInputs(TypedDict): + type: Literal["video_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + indicator_tokens: torch.Tensor + """ + Shape: + `(batch_size * (num_patches + 1))` + """ + + patches_per_image: list[int] + """ + List of number of total patches for each frame in the video. + This is used to restore the first two dimensions of `flat_data`. + """ + + +def _ovis2_5_field_config(): + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video")) + + +class VisualTokenizer(torch.nn.Module): + """ + VIT + """ + + def __init__( + self, + config: PretrainedConfig, + visual_vocab_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.vit = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.vit", + use_data_parallel=use_data_parallel, + ) + # reserved tokens for INDICATOR_IDS + head_dim = visual_vocab_size - len(INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + self.config.hidden_size * self.config.hidden_stride**2, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + model_type = config.model_type + if model_type == "siglip2_navit": + return Siglip2NavitModel(config=config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=use_data_parallel) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self) -> torch.dtype: + return next(self.head.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.head.parameters()).device + + def tokenize(self, logits: torch.Tensor) -> torch.Tensor: + tokens = torch.softmax(logits, dim=-1, + dtype=torch.float32).to(logits.dtype) + return tokens + + def encode(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.vit(pixel_values, grid_thws) + # refer to qwen2.5-vl patchmerger + seq_len, _ = features.shape + features = features.reshape(seq_len // (self.config.hidden_stride**2), + -1) + + return features + + def forward(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.encode(pixel_values, grid_thws) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [#Token, VocabSize-4], + # so padding with [#Token, 4], after which, + # tokens' shape should become [#Token, VocabSize]; + tokens = torch.nn.functional.pad( + tokens, + (0, len(INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class Ovis2_5ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs): + vit_config = self.get_hf_config().vit_config + return self.ctx.get_hf_processor( + Ovis2_5Processor, + image_pad_token=self.get_image_pad_token(), + patch_size=vit_config.patch_size, + hidden_stride=vit_config.hidden_stride, + temporal_patch_size=vit_config.temporal_patch_size, + ) + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + def get_image_size_with_most_features(self) -> ImageSize: + # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code + # TODO(myselvess): Be adjusted based on the max_pixels + return ImageSize(width=1792, height=1792) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + temporal_patch_size = vit_config.temporal_patch_size + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + (-num_frames % temporal_patch_size) + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = image_height // patch_size + grid_w = image_width // patch_size + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches + return num_vision_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + num_frames = 0 + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + if next_max_tokens > max_tokens: + break + num_frames = next_num_frames + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + return max(max_frames_per_video, 1) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[BaseImageProcessor], + ) -> int: + num_video_tokens = self.get_num_image_tokens(image_width=image_width, + image_height=image_height, + num_frames=num_frames) + return num_video_tokens + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + return mm_data + + +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] + ): + + def visual_indicators_to_visual_tokens( + self, + visual_indicators: list[int], + ) -> list[int]: + """ + Filter image indicators placeholders and convert them to corresponding + tokens in visual tokenizer. + """ + hf_config = self.info.get_hf_config() + vte_vocab_size = hf_config.visual_vocab_size + return [ + vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 + for x in visual_indicators if x < -300 + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + hf_processor = self.info.get_hf_processor() + + if "videos" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), True) + for grid in processed_outputs["video_grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + processed_outputs["video_indicator_tokens"] = indicator_tokens + if "images" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), False) + for grid in processed_outputs["grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + + processed_outputs["indicator_tokens"] = indicator_tokens + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _ovis2_5_field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> list[PromptReplacement]: + + def get_replacement_ovis(item_idx, modality: str): + if modality == "image": + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data + elif modality == "video": + out_item = out_mm_kwargs["video"][item_idx] + grid = out_item["video_grids"].data + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_visual_placeholders(grid[0], ) + + return [ + PromptReplacement( + modality=modality, + target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, + replacement=partial(get_replacement_ovis, modality=modality), + ) for modality in ("image", "video") + ] + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder) +class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config: PretrainedConfig = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = VisualTokenizer( + config=config.vit_config, + visual_vocab_size=config.visual_vocab_size, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + ) + + self.vte = VisualEmbedding(config.visual_vocab_size, + config.hidden_size) + + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + + self.make_empty_intermediate_tensors = ( + self.get_language_model().make_empty_intermediate_tensors) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}") + + return OvisImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}") + + return OvisVideoPatchInputs( + type="video_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs] + ) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + indicator_tokens = image_input["indicator_tokens"] + grid_thws = image_input["grids"] + + indicator_per_image = list( + map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype), grid_thws) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. + indicator_embeds = self.vte(indicator_tokens) + + visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) + indicator_embeds_per_image = indicator_embeds.split( + indicator_per_image) + + vision_embeddings = [] + for indicator, visual in zip(indicator_embeds_per_image, + visual_embeds_per_image): + vision_embeddings_per_image = [] + visual = visual.unsqueeze(0) + for i in range(visual.shape[0]): + vision_embeddings_per_image.append( + torch.cat([indicator[i:i + 1], visual[i]], dim=0)) + vision_embeddings_per_image.append(indicator[i + 1:]) + vision_embeddings.append( + torch.cat(vision_embeddings_per_image, dim=0)) + return tuple(vision_embeddings) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "indicator_tokens", + "grids") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("video_pixel_values", "video_indicator_tokens", + "video_grids") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + multimodal_embeddings: tuple[torch.Tensor, ...] = () + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_image_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + tmp = torch.concat(multimodal_embeddings, dim=0) + inputs_embeds[input_ids == self.image_pad_token_id] = tmp + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.compute_logits(hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b1f2e53b0c7128596376b0af48af3d7ad6382a19..b74a09ee92c337bee2ad409b2fc4cbd7d6f1a216 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -21,6 +21,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel @@ -32,19 +33,27 @@ from .vision import get_vision_encoder_info logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - +class PaliGemmaImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class PaliGemmaImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, @@ -146,7 +155,7 @@ class PaliGemmaMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -194,10 +203,13 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + mm_inputs = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() @@ -280,19 +292,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -302,22 +301,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values(pixel_values), - ) + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs(type="pixel_values", + data=pixel_values, + resolve_bindings={ + "h": h, + "w": w + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index f8db99eb92ba87b64147b5c985e7a74200f435be..6bdd38d0688003b77b4a0f3f235120391574f03c 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -255,7 +256,7 @@ class PersimmonModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 21d517b3a490fc958f05c10bd88b5cacd2c6ce2d..789b24eb0f6be05015be4beca54ed33d305a51e0 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -38,6 +38,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -240,7 +241,7 @@ class PhiModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 9ef4f8371eb3d9bc5969c1efc679d6b5cf4ee72c..4522c7043d01a224094e8311c14c4465dbeb7a4e 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,15 +32,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + ResolvedPromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -410,7 +412,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -431,24 +433,38 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): return [_IMAGE_TOKEN_ID] * num_image_tokens - num_images = mm_items.get_count("image", strict=False) - return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:num_images] + ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images - if len(mm_item_counts): + if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() # to decode token_ids to the original text, we need to # 1. remove the first bos token @@ -484,7 +500,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, - mm_item_counts=mm_item_counts, ) # Keep the behavior in line with HF processor diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index e13b8276bf17aae7cc01cfab02923a330c2ef198..6d973a964de0449e66302710747916d5cd50bb55 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -30,7 +30,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) @@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module): return loaded_params -class Phi4MMImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] +class Phi4MMImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` + type: Literal["pixel_values"] - This should be in `(height, width)` format. - """ + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] -class Phi4MMImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class Phi4MMImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match language model backbone) + """ + + type: Literal["image_embeds"] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", "h"), + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - f: Number of Mel filterbank bins (80) + - t: Time frames (M) + """ -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] + + +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ -class Phi4MMAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] @@ -1029,11 +1070,11 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.vocab[tokenizer.image_token] - audio_token_id = tokenizer.vocab[tokenizer.audio_token] + image_token_id: int = tokenizer.vocab[tokenizer.image_token] + audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) audio_processor = self.info.get_feature_extractor( @@ -1053,9 +1094,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [image_token_id] * num_image_tokens - - return image_tokens + return [image_token_id] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -1066,9 +1105,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [audio_token_id] * audio_embed_size - - return audio_tokens + return [audio_token_id] * audio_embed_size return [ PromptReplacement( @@ -1174,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1263,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1273,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 73e8446e6dea74b82cb9f15a626636277c400649..352ae4064cc6174c021f9a5a388fcd50d899a0d6 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -21,16 +21,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, ResolvedPromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -262,9 +263,9 @@ class Phi4MMImageEncoder(nn.Module): img_features.shape[1])) assert base_feat_height == base_feat_height_target \ and base_feat_width == base_feat_height_target, \ - f'base_feat_height: {base_feat_height},"\ - f" base_feat_width: {base_feat_width}, "\ - f"expect {base_feat_height_target} features for hd transform' + (f"base_feat_height: {base_feat_height}, " + f"base_feat_width: {base_feat_width}, " + f"expect {base_feat_height_target} features for hd transform") # bs x max_num_crops x (24x24) x C img_features = img_features.view(bs, -1, @@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module): return img_set_tensor -class Phi4MMImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] +class Phi4MMImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` + type: Literal["pixel_values"] - This should be in `(height, width)` format. - """ + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] -class Phi4MMAudioFeatureInputs(TypedDict): +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - t: Time frames (M) + """ + type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] @@ -802,7 +833,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_tokens: list[str] = self.info.image_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore @@ -824,9 +855,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens - - return image_tokens + return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -837,28 +866,39 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - return audio_tokens - - num_images = mm_items.get_count("image", strict=False) - num_audios = mm_items.get_count("audio", strict=False) - - image_repl = [ + return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_image_replacement_phi4mm, - ) for image_token in image_tokens[:num_images] - ] - audio_repl = [ + ), PromptReplacement( modality="audio", - target=audio_token, + target=audio_tokens.__getitem__, replacement=get_audio_replacement_phi4mm, - ) for audio_token in audio_tokens[:num_audios] + ), ] - return image_repl + audio_repl + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update @MULTIMODAL_REGISTRY.register_processor( @@ -976,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1022,8 +1054,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ] return audio_embeds - def _parse_and_validate_image_input(self, - **kwargs: object) -> Optional[dict]: + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1065,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1075,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index cfe0982204fa95816a11b8cea0c756de4ad6feaa..15ae081a9f5fca000283bd92131901d427df9936 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only PhiMoE model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -500,7 +501,7 @@ class PhiMoEModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 3f7375e650d234a39f7b7f94f6a2a4055bbf88cd..f9a2f9df348aa7e685b93b3f8e675e81d1fe38d3 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -15,7 +15,7 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image -from transformers import PixtralVisionConfig, TensorType +from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens) @@ -33,13 +33,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -47,6 +48,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -67,15 +69,20 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -156,10 +163,12 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) - return { - "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), - "images": images_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(images_tokens)[None].expand(len(text), -1), + "images": + images_processed, + }) class PixtralProcessingInfo(BaseProcessingInfo): @@ -273,7 +282,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -307,24 +316,18 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - ( - prompt_ids, - mm_kwargs, - mm_hashes, - _, - ) = super()._cached_apply_hf_processor( + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, @@ -388,10 +391,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e5034b536266a5605d891a8d14721dedc007030c..7f70e44b10a6d8eaceb610f6d4facb7ed09f8c78 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -614,7 +615,7 @@ class Plamo2Decoder(torch.nn.Module): mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None if layer.is_mamba: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 68488829071faa0113aadb27aa7efbd5f36896ea..2edc357d2df1bec8975d95c56c0dd6216a3cf593 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -18,7 +18,7 @@ """Inference-only IBM/NASA Prithvi Geospatial model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -27,21 +27,61 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, - default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalFieldElem, MultiModalInputs, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalSharedField, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + PlaceholderRange) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModal) +from .interfaces_base import default_pooling_type + + +def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): + # This model receives in input a multi-dimensional tensor representing + # a single image patch and therefore it is not to be split + # into multiple elements, but rather to be considered a single one. + # Hence, the decision of using a MultiModalSharedField. + # The expected shape is (num_channels, width, height). + + # This model however allows the user to also submit multiple image + # patches as a batch, adding a further dimension to the above shape. + # At this stage we only support submitting one patch per request and + # batching is achieved via vLLM batching. + # TODO (christian-pinto): enable support for multi patch requests + # in tandem with vLLM batching. + return dict( + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + ) + + +class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={"pixel_values", "location_coords"}, + fields_factory=_prithvi_field_config, + ) + + return super()._parse_image_data(data) + class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): @@ -63,32 +103,32 @@ class PrithviGeoSpatialMAEInputBuilder( # This model input is fixed and is in the form of a torch Tensor. # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. - return { + image_data = { "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } + return {"image": image_data} + class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): + def _get_data_parser(self) -> MultiModalDataParser: + return PrithviGeoSpatialMAEMultiModalDataParser() + def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) + return _prithvi_field_config(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: return [] @@ -98,46 +138,35 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: - mm_kwargs = {} - - for k, v in mm_data.items(): - if isinstance(v, dict) and k == "image": - mm_kwargs.update(v) - else: - mm_kwargs[k] = v + if "image" in mm_data: + image_data = mm_data["image"] + else: + image_data = mm_data + mm_data = {"image": mm_data} + + mm_items = self._to_mm_items(mm_data) + tokenization_kwargs = tokenization_kwargs or {} + mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else + self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs)) mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). - - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. - multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems([ - MultiModalFieldElem( - modality="image", - key=key, - data=data, - field=MultiModalSharedField(1), - ) for key, data in mm_kwargs.items() - ]) - ] + mm_processed_data = BatchFeature(image_data) + + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs(multimodal_kwargs_items), - mm_hashes=None, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, mm_placeholders=mm_placeholders, ) @@ -148,10 +177,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder, ) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, - SupportsMultiModalWithRawInput): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): """Prithvi Masked Autoencoder""" + supports_multimodal_raw_input_only = True is_pooling_model = True @classmethod diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7f8a565b37d3aac0be73931114773ba2413a0cdf..4f7501ff4c31e559cf8d572ef9989bf0092b8bc0 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,6 +8,7 @@ """Inference-only QWen model compatible with HuggingFace weights.""" import json from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -245,7 +246,7 @@ class QWenModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index dcd7096907eac01ed8c91cfb6a8ad1a6475a5e47..4b839b0c458849134b66c6f5be001da0872a3f84 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -25,6 +25,7 @@ # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -36,6 +37,7 @@ import re import vllm.envs as envs from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -55,7 +57,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -165,7 +167,9 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -353,7 +357,7 @@ class Qwen2Model(nn.Module): self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -378,7 +382,7 @@ class Qwen2Model(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) @@ -514,7 +518,7 @@ class Qwen2Model(nn.Module): return loaded_params -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -570,6 +574,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index e95295c31885a05eddca61dd4107df86bdaa5acf..5c64c81547e651650c4e835e857bf3671c92153d 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -47,18 +47,19 @@ from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, + Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -78,37 +79,57 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) -def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) - - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) - - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, + MultiModalFieldConfig]]: + + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, + torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + num_videos = len(video_grid_sizes) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared( + "video", num_videos), + ) - return dict( - input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - audio_feature_lengths=MultiModalFieldConfig.batched("audio"), - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), - ) + return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(self._spatial_merge_size, *args, **kwargs) + def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -120,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): required_fields={ "input_audio_features", "audio_feature_lengths" }, - fields_factory=_qwen2_5_omni_thinker_field_config, + fields_factory=create_qwen2_5_omni_thinker_field_factory( + self._spatial_merge_size), ) return super()._parse_audio_data(data) @@ -210,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config( + ).vision_config.spatial_merge_size, target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -246,6 +270,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( if ('audio_feature_lengths' not in hf_inputs and feature_attention_mask is not None): hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + + video_second_per_grid = hf_inputs.get("video_second_per_grid", None) + if video_second_per_grid is not None: + hf_inputs["second_per_grid_ts"] = video_second_per_grid + + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video) + return hf_inputs def _get_mm_fields_config( @@ -253,38 +285,32 @@ class Qwen2_5OmniThinkerMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2_5_omni_thinker_field_config(hf_inputs) + return create_qwen2_5_omni_thinker_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. """ - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = (all( + item["use_audio_in_video"].data + for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, @@ -301,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders( mm_placeholders, @@ -311,16 +336,13 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - if use_audio_in_video: - mm_kwargs["use_audio_in_video"] = True - return prompt_ids, prompt, mm_placeholders def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -335,8 +357,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor( image_token_id = vocab[image_token] video_token_id = vocab[video_token] - audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") if audio_feature_lengths is None and feature_attention_mask is None: audio_output_lengths = [] elif audio_feature_lengths is not None: @@ -366,7 +389,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( return [audio_token_id] * num_features def get_replacement_qwen2_vision(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) merge_length = image_processor.merge_size**2 @@ -382,7 +405,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] - video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 @@ -511,7 +534,7 @@ class Qwen2_5OmniConditionalGenerationMixin: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) @@ -525,9 +548,10 @@ class Qwen2_5OmniConditionalGenerationMixin: if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio input features. " f"Got type: {type(input_audio_features)}") - return Qwen2AudioInputs(input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + return Qwen2AudioFeatureInputs( + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) def _parse_and_validate_image_input( self, @@ -607,7 +631,7 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2AudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index b44e8ddf2d4fc15177db65fc4ab810b2f14013b9..a43312019d7f2978529813b06cc027bf2c941264 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -45,10 +45,14 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + MergedReplicatedLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -57,6 +61,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -136,7 +141,7 @@ class Qwen2_5_VLVideoPixelInputs(TypedDict): second_per_grid_ts: torch.Tensor """ - The video time interval (in seconds) for each grid along the temporal + The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. """ @@ -176,19 +181,25 @@ class Qwen2_5_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( + cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else + MergedColumnParallelLinear) + self.gate_up_proj = cls_gate_up_proj( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + + cls_down_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.down_proj = cls_down_proj(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -226,28 +237,42 @@ class Qwen2_5_VisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + parallel_state.get_tensor_model_parallel_world_size()) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + if use_data_parallel: + self.qkv = ReplicatedLinear(embed_dim, + self.hidden_size_per_attention_head * + 3 * num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + else: + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + + cls_proj = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.proj = cls_proj(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -309,8 +334,6 @@ class Qwen2_5_VisionAttention(nn.Module): if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) # if self.attn_backend == _Backend.ROCM_AITER_FA: # from aiter import flash_attn_varlen_func # else: @@ -377,23 +400,27 @@ class Qwen2_5_VisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = Qwen2_5_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) def forward( self, @@ -438,7 +465,8 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): L, C = x.shape x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = x.to(memory_format=torch.channels_last_3d) + if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1': + x = x.to(memory_format=torch.channels_last_3d) x = self.proj(x).view(L, self.hidden_size) return x @@ -453,24 +481,30 @@ class Qwen2_5_VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) + + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + cls_fc1(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + cls_fc2(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -522,6 +556,7 @@ class Qwen2_5_VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -531,6 +566,8 @@ class Qwen2_5_VisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw self.window_size = vision_config.window_size @@ -558,7 +595,8 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(depth) ]) self.merger = Qwen2_5_VisionPatchMerger( @@ -568,6 +606,7 @@ class Qwen2_5_VisionTransformer(nn.Module): spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -786,7 +825,6 @@ class Qwen2_5_VisionTransformer(nn.Module): if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -877,6 +915,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -888,6 +931,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -902,6 +947,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config @@ -913,6 +959,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config=self._maybe_ignore_quant_config( self.quant_config), prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) else: self.visual = None @@ -1035,7 +1082,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list) + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -1057,8 +1110,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list) + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index bbae82c7a40d32601a6bbf764e01c19ceca40b4a..5f6c3b55fd35649e8450d335f3511cac9b5c93b9 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import torch import torch.nn as nn @@ -36,9 +36,11 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, +from vllm.multimodal.inputs import (AudioItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -52,7 +54,8 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model, # # === Audio Inputs === # -class Qwen2AudioInputs(TypedDict): +class Qwen2AudioFeatureInputs(TypedDict): + type: Literal["audio_features"] input_features: torch.Tensor """Shape: `(num_audios, num_mel_bins, 3000)`""" @@ -60,6 +63,16 @@ class Qwen2AudioInputs(TypedDict): """Shape: `(num_audios, 3000)`""" +class Qwen2AudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + audio_embeds: list[torch.Tensor] + """Shape: `(num_audio_features, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] + # === Audio Encoder === # @@ -128,12 +141,38 @@ class Qwen2AudioDummyInputsBuilder( } +def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + + +class Qwen2AudioMultiModalDataParser(MultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_qwen2audio_field_config, + ) + + return super()._parse_audio_data(data) + + class Qwen2AudioMultiModalProcessor( BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -173,17 +212,15 @@ class Qwen2AudioMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - input_features=MultiModalFieldConfig.batched("audio"), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - ) + return _qwen2audio_field_config(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -199,7 +236,8 @@ class Qwen2AudioMultiModalProcessor( audio_bos_id = vocab[audio_bos_token] audio_eos_id = vocab[audio_eos_token] - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") if feature_attention_mask is None: audio_output_lengths = [] else: @@ -210,7 +248,15 @@ class Qwen2AudioMultiModalProcessor( audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - num_features = audio_output_lengths[item_idx] + + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape + ) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) @@ -285,21 +331,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Qwen2AudioInputs]: input_features = kwargs.pop('input_features', None) + audio_embeds = kwargs.pop('audio_embeds', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) - if input_features is None: + + if input_features is None and audio_embeds is None: return None - input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") - return Qwen2AudioInputs(input_features=input_features, - feature_attention_mask=feature_attention_mask) - - def _process_audio_input(self, - audio_input: Qwen2AudioInputs) -> torch.Tensor: + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + audio_embeds = self._validate_and_reshape_mm_tensor( + audio_embeds, "audio_embeds") + return Qwen2AudioEmbeddingInputs(type="audio_embeds", + audio_embeds=audio_embeds) + + if input_features is not None: + input_features = self._validate_and_reshape_mm_tensor( + input_features, 'input_features') + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + return Qwen2AudioFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: Qwen2AudioInputs + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"] diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 2ff4013c6cf4cc0e069916ea2204a0bfd67376d7..5c1d1a7fae696539f136b6fcda30bcb1bd2fff10 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -25,6 +25,7 @@ # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -395,7 +396,7 @@ class Qwen2MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602a86db205886295ee8d1e39483f36..421b43563bade76fa5993d7ee7669e463b9ef571 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 4375a693e2e38dc340faa6df799fc376dcfa74c2..f3e8df99f0fc4172201ac436a968d5654a40b0eb 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,7 +58,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -335,8 +335,6 @@ class Qwen2VisionAttention(nn.Module): k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) # if self.attn_backend == _Backend.ROCM_AITER_FA: # from aiter import flash_attn_varlen_func # else: @@ -467,7 +465,8 @@ class Qwen2VisionPatchEmbed(nn.Module): L, C = x.shape x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = x.to(memory_format=torch.channels_last_3d) + if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1': + x = x.to(memory_format=torch.channels_last_3d) x = self.proj(x).view(L, self.embed_dim) return x @@ -764,29 +763,46 @@ class Qwen2VisionTransformer(nn.Module): return loaded_params -def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) - - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) +def _create_qwen2vl_field_factory( + spatial_merge_size: int +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: + + def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + return _qwen2vl_field_config class Qwen2VLMultiModalDataParser(MultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(*args, **kwargs) + def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -796,7 +812,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -810,7 +827,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_video_data(data) @@ -1032,13 +1050,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2VLMultiModalDataParser() + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1054,7 +1073,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1074,7 +1094,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index be539f6f54828b88d8c3fd459e79a7127f6cb2ff..ecec15862b12ada65577a38da84ec9b48cfe0ed3 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -314,10 +314,10 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index e717d6c1ae400cd74ee94a60936aa09d811bd13e..9f43162d5ecce1a8463878ff070cf352043c4ec0 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -24,6 +24,7 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import os @@ -47,6 +48,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -127,11 +131,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -145,18 +149,31 @@ class Qwen3MoeSparseMoeBlock(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, + reduce_results=True, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization while AutoRound does. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4, + # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq + if isinstance( + quant_config, + (GPTQConfig, + GPTQMarlinConfig)) and not quant_config.autoround_version: + return None + return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -169,10 +186,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) - return final_hidden_states.view(orig_shape) @@ -379,7 +392,8 @@ class Qwen3MoeModel(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -438,8 +452,7 @@ class Qwen3MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -756,4 +769,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() \ No newline at end of file + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 4c3fd6b5156d00b440c804c45a0325662d28a43e..90200f319464b606ef8abeec62a54e1e52aac555 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -11,7 +11,7 @@ import math import unicodedata from collections.abc import Collection, Mapping, Sequence, Set from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, Optional, Union import regex as re import torch @@ -33,13 +33,14 @@ from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel from .utils import flatten_bn, merge_multimodal_embeddings -class QwenImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor +class QwenImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 3, image_size, image_size)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + Note that image_size is the value in the vision config to which we resize the image to in the normalization transform. Currently multi-image support can only be leveraged by passing image embeddings directly. """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class QwenImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, 256, hidden_size)` - +class QwenImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size (256) + - hs: Hidden size + `hidden_size` must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one. """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")] QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] @@ -627,7 +636,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() special_tokens: dict[str, @@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, self.transformer: QwenVLModel - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.visual["image_size"] - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[QwenImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + expected_h = expected_w = self.config.visual["image_size"] + resolve_bindings = {"h": expected_h, "w": expected_w} + return QwenImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + data=flatten_bn(pixel_values, concat=True), + resolve_bindings=resolve_bindings, ) if image_embeds is not None: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2d9a22da7e006928a5a04e2d9fefcb623783e41c..5165743ef3f64bf7b2ae7d35d2fdae20cbcb66e1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -25,17 +25,21 @@ from vllm.logger import init_logger from vllm.transformers_utils.dynamic_module import ( try_get_class_from_dynamic_module) -from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, - is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model +from .interfaces import (has_inner_state, has_noops, is_attention_free, + is_hybrid, supports_cross_encoding, + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input_only, supports_pp, + supports_transcription, supports_v0_only) +from .interfaces_base import (get_default_pooling_type, is_pooling_model, + is_text_generation_model) logger = init_logger(__name__) # yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] + "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), @@ -94,6 +98,7 @@ _TEXT_GENERATION_MODELS = { "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 # For decapoda-research/llama-* @@ -130,6 +135,7 @@ _TEXT_GENERATION_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), @@ -187,12 +193,14 @@ _EMBEDDING_MODELS = { _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "GteNewForSequenceClassification": ("bert_with_rope", + "GteNewForSequenceClassification"), + "ModernBertForSequenceClassification": ("modernbert", + "ModernBertForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), - "ModernBertForSequenceClassification": ("modernbert", - "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } @@ -205,6 +213,7 @@ _MULTIMODAL_MODELS = { "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 @@ -215,9 +224,11 @@ _MULTIMODAL_MODELS = { "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 + "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), @@ -232,6 +243,7 @@ _MULTIMODAL_MODELS = { "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), + "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), @@ -249,6 +261,7 @@ _MULTIMODAL_MODELS = { "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 @@ -264,7 +277,9 @@ _SPECULATIVE_DECODING_MODELS = { "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), # Temporarily disabled. @@ -315,7 +330,8 @@ class _ModelInfo: default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool - supports_multimodal_raw_input: bool + supports_multimodal_raw_input_only: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -334,7 +350,10 @@ class _ModelInfo: default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), - supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_raw_input_only= + supports_multimodal_raw_input_only(model), + supports_multimodal_encoder_tp_data= + supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -729,13 +748,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.supports_multimodal - def supports_multimodal_raw_input( + def is_multimodal_raw_input_only_model( self, architectures: Union[str, list[str]], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) - return model_cls.supports_multimodal_raw_input + return model_cls.supports_multimodal_raw_input_only def is_pp_supported_model( self, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 32a4a2c9a26941bb902ee87d0d34e87bd8452ceb..2bfa51162910b56e515c952222052162f60b539a 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,6 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -23,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module): @@ -100,7 +100,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id def forward( self, @@ -178,7 +178,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -233,58 +233,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): intermediate_tensors=intermediate_tensors) -# Adapted from transformers -def create_position_ids_from_input_ids(input_ids, - padding_idx, - past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. - Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ - # The series of casts and type-conversions here are carefully - # balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - - incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + - past_key_values_length) * mask - - return incremental_indices.long() + padding_idx - - def replace_roberta_positions(input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int) -> None: - - seq_lens: Optional[torch.Tensor] = None - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: # can be None during warmup - if isinstance(attn_metadata, dict): - attn_metadata = next(iter(attn_metadata.values())) - # TODO: remove "seq_lens_tensor" after V0 is removed - seq_lens = getattr(attn_metadata, "seq_lens_tensor", - getattr(attn_metadata, "seq_lens", None)) - - if seq_lens is not None: - assert isinstance(seq_lens, torch.Tensor) - - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids[:torch.sum(seq_lens)], - seq_lens.tolist()) - - offset = 0 - for tokens in token_list: - length = tokens.shape[0] - position_ids[offset:offset+length] = \ - create_position_ids_from_input_ids(tokens, padding_idx) - offset = offset + length + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + # vllm does not use padding tokens, let's make things simpler + position_ids += padding_idx + 1 diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py new file mode 100644 index 0000000000000000000000000000000000000000..efdb010046634a4e8a83afc04660863214a7db28 --- /dev/null +++ b/vllm/model_executor/models/rvl.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class RVLProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + +class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "<image>" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class RVLMultiModalProjector(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=RVLProcessingInfo, + dummy_inputs=RVLDummyInputsBuilder, +) +class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = RVLMultiModalProjector(config) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c7c700f8fa1557f38518536f208f43ae048690 --- /dev/null +++ b/vllm/model_executor/models/seed_oss.py @@ -0,0 +1,488 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Seed team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only SeedOss model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from itertools import islice +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig as SeedOssConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class SeedOssMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SeedOssAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + self.head_dim = head_dim + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class SeedOssDecoderLayer(nn.Module): + + def __init__( + self, + config: SeedOssConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, SeedOss uses causal attention as it is a + # decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = SeedOssAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = SeedOssMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class SeedOssModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but all layers is not supported. " + "This model uses sliding window but `max_window_layers` = {} " + "is less than `num_hidden_layers` = {}. Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to SeedDecoderLayer + decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = SeedOssModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py new file mode 100644 index 0000000000000000000000000000000000000000..c6244fb3b3e6a49ceea6063b980bd652fe511fde --- /dev/null +++ b/vllm/model_executor/models/siglip2navit.py @@ -0,0 +1,688 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F +from transformers import Siglip2VisionConfig +from transformers.configuration_utils import PretrainedConfig + +from vllm.config import QuantizationConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import _Backend + +from .vision import get_vit_attn_backend + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Siglip2VisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.image_size = config.image_size + self.num_patches = config.num_patches + self.preserve_original_pe = config.preserve_original_pe + self.hidden_stride = config.hidden_stride + + # siglip2 naflex + if self.num_patches > 0: + self.patch_embedding = ReplicatedLinear( + input_size=config.num_channels * self.patch_size * + self.patch_size, + output_size=self.embed_dim, + return_bias=False, + ) + if self.preserve_original_pe: + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + else: + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + if self.preserve_original_pe: + self.num_patches = (self.image_size // self.patch_size)**2 + self.position_embedding_size = (self.image_size // + self.patch_size) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape ( + num_patches, + num_channels * temporal_patch_size * patch_size * patch_size + ) + grid_thws: (`torch.LongTensor`): + grid shape (num_patches, 3) + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + if isinstance(self.patch_embedding, LinearBase): + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + elif isinstance(self.patch_embedding, nn.Conv2d): + pixel_values = pixel_values.view( + -1, self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, self.patch_size) + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.reshape(-1, self.embed_dim) + + if self.preserve_original_pe: + assert grid_thws is not None + pos_embed_new = torch.zeros_like(patch_embeds) + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, + -1).unsqueeze(0).permute(0, 3, 1, 2) + cnt = 0 + for t, h, w in grid_thws: + volume = t * h * w + pe = F.interpolate(positional_embeddings, + size=(h, w), + mode='bicubic', + align_corners=False) + pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) + pe = pe[0].repeat(t, 1) + pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, + w // self.hidden_stride, self.hidden_stride, + -1) + pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) + pos_embed_new[cnt:cnt + volume] = pe + cnt += volume + patch_embeds = patch_embeds + pos_embed_new + + return patch_embeds + + +# copy from flash_attn/layers/rotary.py +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_flash_attn_backend: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + if is_flash_attn_backend: + from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb + else: + apply_rotary_emb_func = apply_rotary_emb_torch + q_embed = apply_rotary_emb_func(q.float(), cos.float(), + sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), + sin.float()).type_as(k) + return q_embed, k_embed + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.use_rope = config.use_rope + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, + _Backend.ROCM_AITER_FA + }: + self.attn_backend = _Backend.TORCH_SDPA + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, + torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + seq_length, embed_dim = hidden_states.shape + + qkv_states, _ = self.qkv_proj(hidden_states) + queries, keys, values = qkv_states.chunk(3, dim=-1) + + queries = queries.view(seq_length, self.num_heads_per_partition, + self.head_dim) + keys = keys.view(seq_length, self.num_heads_per_partition, + self.head_dim) + values = values.view(seq_length, self.num_heads_per_partition, + self.head_dim) + + if self.use_rope: + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0), + keys.unsqueeze(0), cos, sin, + self.is_flash_attn_backend) + queries = queries.squeeze(0) + keys = keys.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + if self.is_flash_attn_backend: + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( + queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, + max_seqlen).reshape(seq_length, -1) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + batch_size = cu_seqlens.shape[0] - 1 + outputs = [] + cu = cu_seqlens.tolist() + for i in range(batch_size): + start_idx = cu[i] + end_idx = cu[i + 1] + + # Each sequence is processed independently. + q_i = queries[start_idx:end_idx].unsqueeze(0) + k_i = keys[start_idx:end_idx].unsqueeze(0) + v_i = values[start_idx:end_idx].unsqueeze(0) + + # (1, seq_len, num_heads, head_dim) -> + # (1, num_heads, seq_len, head_dim) + q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] + + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) + output_i = output_i.transpose(1, 2).reshape( + end_idx - start_idx, -1) + outputs.append(output_i) + + attn_output = torch.cat(outputs, dim=0) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class Siglip2MLP(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. + + Args: + config: PretrainedConfig + """ + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Siglip2EncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel) + for idx in range(config.num_hidden_layers) + ]) + + self.rotary_pos_emb = VisionRotaryEmbedding( + config.hidden_size // config.num_attention_heads // 2) + self.patch_size = config.patch_size + self.hidden_stride = config.hidden_stride + self.window_size = config.window_size + self.spatial_merge_unit = config.hidden_stride * config.hidden_stride + if config.fullatt_block_indexes is None: + self.fullatt_block_indexes = None + else: + self.fullatt_block_indexes = [ + int(i) for i in config.fullatt_block_indexes.split('|') + ] + + # copied from qwen2.5_vl + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + # patch (after merge) number in each window + vit_merger_window_size = (self.window_size // self.hidden_stride // + self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.hidden_stride, # number of patch after merge + grid_w // self.hidden_stride, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + inputs_embeds: torch.Tensor, + grid_thws: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if + you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding + lookup matrix. + grid_thws (`torch.LongTensor`): + grid shape (num_patches, 3) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of + a plain tuple. + """ + rotary_pos_emb = self.rot_pos_emb(grid_thws) + window_index, cu_window_seqlens = self.get_window_index(grid_thws) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=inputs_embeds.device, + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = inputs_embeds.size() + inputs_embeds = inputs_embeds.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + inputs_embeds = inputs_embeds[window_index, :, :] + inputs_embeds = inputs_embeds.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have + # same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 + # for more information + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + reverse_indices = torch.argsort(window_index) + + hidden_states = inputs_embeds + for index, block in enumerate(self.layers): + if (not self.fullatt_block_indexes + or index in self.fullatt_block_indexes): + cu_seqlens_tmp = cu_seqlens + else: + cu_seqlens_tmp = cu_window_seqlens + hidden_states = block(hidden_states, cu_seqlens_tmp, + position_embeddings) + + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) + + return hidden_states + + +class Siglip2VisionTransformer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) + of the input images. + """ + hidden_states = self.embeddings(pixel_values, grid_thws) + + last_hidden_state = self.encoder(hidden_states, grid_thws) + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Siglip2NavitModel(torch.nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + + self.vision_model = Siglip2VisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + use_data_parallel=use_data_parallel) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + grid_thws=grid_thws, + ) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index c76aabcd27ccb00f284c55ba81a760a84d90fd7c..9857ccdcbe2d43bcee3e67d430477e66f5a7d4a7 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -26,7 +26,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) -class SkyworkR1VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor +class SkyworkR1VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + pixel_values_flat: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "h", "w"), + ] + num_patches: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] -class SkyworkR1VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class SkyworkR1VImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - ni: Number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "ifs", "hs"), + ] SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, @@ -552,18 +568,19 @@ class SkyworkR1VMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -730,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -787,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return SkyworkR1VImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index d6ec743ce845e30db03a1fb1857541299ffaf2b9..9e880ebd508133a1954d77535c9ff858bf665626 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -22,6 +22,7 @@ """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -247,7 +248,7 @@ class StableLMEpochModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 9d9a2bff0e43fc6e43175b2b9f03b3f10b710e2b..62ff9b6182755eec2894f660c7278dff9e073c91 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -21,6 +21,7 @@ # limitations under the License. """ PyTorch Starcoder2 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -250,7 +251,7 @@ class Starcoder2Model(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 47d2af5c2a1400a9efebca82e40d34169f5ffb71..97611d3e140ec942c48bb9d450b48dd2f3308557 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -346,8 +347,7 @@ class Step3TextModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 2ddf5031a5bb03b2f141014354cec26259feaada..dac8275d5c1b97d2072c2eb2803933fcef6e14c2 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -520,20 +520,18 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_placeholder_token_id = hf_processor.image_token_id - batch_num_patches = out_mm_kwargs["num_patches"].tolist() def get_replacement_step1o(item_idx: int): - img_out = out_mm_kwargs.get_item("image", item_idx) - num_patches = batch_num_patches[item_idx] + out_item = out_mm_kwargs["image"][item_idx] + num_patches = int(out_item["num_patches"].data) if num_patches > 0: - patch_newline_mask = img_out["patch_newline_mask"].data.tolist( - ) + patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor._get_image_repl_features( - 1, num_patches, patch_newline_mask)[1] + 1, num_patches, patch_newline_mask.tolist())[1] else: image_repl_ids = hf_processor._get_image_repl_features( 1, 0, None)[1] @@ -869,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -884,8 +884,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Step3VisionTransformer( diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..30b441f5b4df00f84c4e4cd897e8d922d0d4ff59 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index c8709d866b1e7f64a4b14ad6378d8c662bb510d7..c66867315e5536dbbc5b6df001ef5ae62d27a862 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -25,15 +25,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -43,14 +45,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .vision import VisionEncoderInfo, get_vision_encoder_info -class TarsierImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class TarsierImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class TarsierImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor +class TarsierImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] TarsierImageInputs = Union[TarsierImagePixelInputs, @@ -275,7 +291,7 @@ class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index # The <IMAGE> token ID @@ -317,7 +333,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( @@ -432,18 +448,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) # Assuming 3 channels - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -459,8 +463,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return TarsierImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index f3b7263ca38726a9eb7f6481b2c5bc1c65a00452..5ad0482330ecdddc78ee35919f855e3080f297b1 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -17,6 +17,7 @@ """Wrapper around `transformers` models""" from collections.abc import Iterable, Mapping from contextlib import contextmanager +from pathlib import Path from typing import Literal, Optional, Union import regex as re @@ -41,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems @@ -60,6 +61,21 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, logger = init_logger(__name__) +def get_feature_request_tip( + model: str, + trust_remote_code: bool, +) -> str: + hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" + gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" + url = hf_url if trust_remote_code else gh_url + prefix = f"Please open {url} to request support for this feature. " + if Path(model).exists(): + prefix = "" + doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" + tip = f"See {doc_url} for instructions on how to add support yourself." + return f"{prefix}{tip}" + + def vllm_flash_attention_forward( # Transformers args module: torch.nn.Module, @@ -88,9 +104,29 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): logger.debug("%s: %s -> %s", name, old_module, new_module) +def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: + """ + Callable to be passed to `@support_torch_compile`'s `enable_if` argument. + + Defaults to `True` but is disabled in the following situations: + + - The model uses dynamic rope scaling. + """ + enable = True + text_config = vllm_config.model_config.hf_config.get_text_config() + # Dynamic rope scaling is not compatible with torch.compile + rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} + if rope_scaling.get("rope_type") == "dynamic": + enable = False + return enable + + def replace_linear_class( - linear: nn.Linear, style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig + linear: nn.Linear, + style: Literal["colwise", "rowwise"], + quant_config: QuantizationConfig, + *, + prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -124,6 +160,7 @@ def replace_linear_class( output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, return_bias=False, **vllm_linear_kwargs, ) @@ -237,7 +274,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): """ Given the original multi-modal items for this modality @@ -310,7 +347,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -372,14 +409,16 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_tokens_per_modality["num_image_patches"] ) if "num_image_patches" in mm_tokens_per_modality else None processed_data['num_image_patches'] = num_image_patches - mm_kwargs = MultiModalKwargs.from_hf_inputs( + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( processed_data, self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, num_image_patches), ) + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else + self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs)) - mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs) return MultiModalInputs( type="multimodal", prompt=prompt, @@ -457,8 +496,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): return if not self.model.supports_pp_plan: + tip = get_feature_request_tip(self.model_config.model, + self.model_config.trust_remote_code) raise ValueError( - f"{type(self.model)} does not support pipeline parallel yet!") + f"{type(self.model)} does not support pipeline parallel. {tip}" + ) module_lists = [] module_list_idx = None @@ -512,8 +554,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): models_with_tp_plan = filter(supports_tp_plan, pretrained_models) if not any(models_with_tp_plan) and self.tp_size > 1: + tip = get_feature_request_tip(self.model_config.model, + self.model_config.trust_remote_code) raise ValueError( - f"{type(self.model)} does not support tensor parallel yet!") + f"{type(self.model)} does not support tensor parallel. {tip}") def _tensor_parallel(module: nn.Module, prefix: str = "", @@ -538,8 +582,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class(child_module, style, - self.quant_config) + new_module = replace_linear_class(child_module, + style, + self.quant_config, + prefix=qual_name) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) else: @@ -642,7 +688,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersModel(TransformersBase): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -654,7 +700,7 @@ class TransformersModel(TransformersBase): }) -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -709,6 +755,15 @@ def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: MultiModalProcessor, info=MultiModalProcessingInfo, dummy_inputs=MultiModalDummyInputsBuilder) +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }, + enable_if=can_enable_torch_compile) class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): # Backwards compatibility for prev released models. State dicts back then # had different formats and cannot be loaded with `AutoModel` mapping as is diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bef34c1be49fe12e17a5f33a1cdf559b37718cbe..f91c4ddb6e8342fac2ba80c7a1cb938e9e8a1a01 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -194,7 +194,7 @@ class UltravoxMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -203,7 +203,8 @@ class UltravoxMultiModalProcessor( # Each audio can be split into multiple chunks. # chunks_start_idx[i] indicates the start index of the chunks # belonging to the i-th audio. - num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + out_mm_data = out_mm_kwargs.get_data() + num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0)) chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, dim=0, dtype=torch.int32) @@ -213,7 +214,7 @@ class UltravoxMultiModalProcessor( def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] end = chunks_start_idx[item_idx + 1] - audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() + audio_token_len = out_mm_data["audio_token_len"][start:end].sum() return [replacement_id] * int(audio_token_len) # type: ignore return [ diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6c27fedc61b1744967586c109569fb513146927c..28cfefac30ddb9a010e2814bd26bbecdff08961b 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -507,8 +507,10 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor(placeholder_token_id, - device=input_ids.device) + placeholder_token_id = torch.tensor( + placeholder_token_id, + pin_memory=is_pin_memory_available()).to(device=input_ids.device, + non_blocking=True) return _merge_multimodal_embeddings( inputs_embeds, torch.isin(input_ids, placeholder_token_id), diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 6b06c0ac6683f201b6ab16ee63c586def24b61c3..6bc748407a7d1f8123cf786164998942e4b85e19 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -17,7 +17,7 @@ from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder -from transformers import TensorType, WhisperConfig +from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig @@ -31,11 +31,12 @@ from vllm.model_executor.models.whisper import WhisperEncoder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -155,10 +156,12 @@ class VoxtralProcessorAdapter: audios_tokens.append(torch.tensor(audio_tokens)) audios_processed.append(torch.tensor(audio)) - return { - "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), - "audio_arrays": audios_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": + audios_processed, + }) class VoxtralProcessingInfo(BaseProcessingInfo): @@ -259,7 +262,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -287,20 +290,18 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - prompt_ids, mm_kwargs, mm_hashes, _ = super( - )._cached_apply_hf_processor( + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True def _get_data_parser(self) -> MultiModalDataParser: sampling_rate = self.info.get_hf_processor().sampling_rate diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ca02ecd828ba3f8e731694b0bf55ea8cceda9f47..16bbe2f2010a14698edaaeb41acbf4a79a18f727 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -33,7 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, @@ -728,7 +728,7 @@ class WhisperMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: num_tokens = self.info.get_num_audio_tokens() return [ diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 750ee785026880fdb01349c2fa8931e333716e54..9465308e94e659a589c1d3dbe70cecd289e4d197 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Hashable from fractions import Fraction from typing import Callable, Optional, Union +from weakref import WeakValueDictionary import torch from torch.nn import Parameter -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.utils import _make_synced_weight_loader @@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter): into the parameter when the provided weight loader is called. """ - def __new__(cls, data: torch.Tensor, **kwargs): + def __new__(cls, data: Optional[torch.Tensor], **kwargs): return super().__new__(cls, data=data, requires_grad=False) @@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): self._assert_and_load(loaded_weight) + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert isinstance(shard_id, str) + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + class _ColumnvLLMParameter(BasevLLMParameter): """ @@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter): """ def __init__(self, **kwargs): - self.qkv_idxs = {"q": 0, "k": 1, "v": 2} super().__init__(**kwargs) - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: - if isinstance(shard_id, int): - return shard_id - - # if not int, assume shard_id for qkv - # map to int and return - assert isinstance(shard_id, str) - assert shard_id in self.qkv_idxs - return self.qkv_idxs[shard_id] - # For row parallel layers, no sharding needed # load weight into parameter as is def load_row_parallel_weight(self, *args, **kwargs): @@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): pass +class SharedWeightParameter(BasevLLMParameter): + """ + Parameter for weights with many shared tensors across a model + + For example, when applying transforms to the "gate" and "up" partitions of + `MergedColumnParallelLinear`, the transform weights must stay separate + tensors in order to allow for tensor memory sharing between layers. + """ + # global registry for sharing tensors based on passed `data_key` + # this dict holds weaksrefs to avoid memory leak after model cleanup + tensors_registry: WeakValueDictionary = WeakValueDictionary() + + # local container for strong references to shared tensors + # this set compensates for the fact that torch.nn.Parameter + # and Parameter subclasses do not hold reliable references to tensors + local_tensors: set[torch.Tensor] + + # dictionary mapping partition indices to associated parameters + partitions: dict[int, Union[ModelWeightParameter, Parameter]] + + def __new__(cls, **kwargs): + return super().__new__(cls, data=None, **kwargs) + + def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): + weight_loader: Callable = kwargs.get( + "weight_loader") # type: ignore[assignment] + super().__init__(data=None, weight_loader=weight_loader) + + self.local_tensors = set() + self.partitions = {} + self.kwargs = { + "input_dim": input_dim, + "output_dim": output_dim, + "weight_loader": self._fake_weight_loader + } + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > 1: + raise NotImplementedError(f"{self.__class__.__name__} does not " + "currently support tensor parallelism") + + def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): + """ + Add a partition to the weight parameter. Partitions whose `data_key` + is the same will share tensor data + + :param index: index of partition to add + :param data_key: hashable key used to key shared tensors + :param *args: arguments for `torch.empty` + :param **kwargs: keyword arguments for `torch.empty` + """ + # load (shared) tensor using `data_key` + if data_key not in self.tensors_registry: + data = torch.empty(*args, **kwargs) + self.tensors_registry[data_key] = data + else: + data = self.tensors_registry[data_key] + + # create associated model parameter + self.partitions[index] = ModelWeightParameter( + data=data, **self.kwargs) # type: ignore[arg-type] + + # hold local reference, since ModelWeightParameter does not + # see https://github.com/pytorch/pytorch/issues/75932 + self.local_tensors.add(data) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_column_parallel_weight( + partition, loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = kwargs.pop("shard_id") + partition_id = self._shard_id_as_int(partition_id) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + + ModelWeightParameter.load_merged_column_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = self._shard_id_as_int(kwargs.pop("shard_id")) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + shard_id = "q" # fake first partition + num_heads = kwargs.get("num_heads") + + ModelWeightParameter.load_qkv_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size, + shard_id=shard_id, + num_heads=num_heads, + ) + + def process_weights_after_loading(self): + for key in self.partitions: + self.partitions[key] = torch.nn.Parameter( + data=self.partitions[key].data, requires_grad=False) + + @property + def data(self): + raise ValueError("Accessing `data` of a " + "`PartitionedModelWeightParameter` is not allowed. " + "Instead, use `get_partition` to get the weight of " + "the particular partition you want to access") + + def _fake_weight_loader(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_weight_shard_id: Optional[Union[str, int]]): + raise ValueError("When loading partition weights of " + f"{self.__class__.__name__}, use methods provided by " + f"{self.__class__.__name__}, not partition loader") + + def permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter: """ @@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, bitblas_tile_size=bitblas_tile_size) - return shard_size, shard_offset \ No newline at end of file + return shard_size, shard_offset diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py deleted file mode 100644 index e6f1ca61dd29115fa9b8c5977620f6b737bc87ae..0000000000000000000000000000000000000000 --- a/vllm/model_executor/pooling_metadata.py +++ /dev/null @@ -1,79 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any - -import torch - -from vllm.pooling_params import PoolingParams -from vllm.utils import is_pin_memory_available - - -class PoolingMetadata: - """Metadata for pooling operations in the Pooler layer. - - This class holds the necessary information for pooling operations, - providing context for how to perform pooling and other related operations. - - Attributes: - seq_groups: List of (seq_ids, pooling_params). - seq_data: A mapping of sequence ID to additional sequence data. - prompt_lens: List of the lengths of each prompt. - """ - - def __init__( - self, - seq_groups: list[tuple[list[int], PoolingParams]], - seq_data: dict[int, Any], # Specific data related to sequences - prompt_lens: list[int], - ) -> None: - self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens - - def __repr__(self) -> str: - return ("PoolingMetadata(" - f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens})") - - def __getitem__(self, indices: slice): - return PoolingMetadata( - seq_groups=self.seq_groups[indices], - seq_data=dict(list(self.seq_data.items())[indices]), - prompt_lens=self.prompt_lens[indices], - ) - - -@dataclass -class PoolingTensors: - """Tensors for pooling.""" - - prompt_lens: torch.Tensor - - @classmethod - def from_pooling_metadata( - cls, - pooling_metadata: "PoolingMetadata", - device: torch.device, - ) -> "PoolingTensors": - """ - Create PoolingTensors from PoolingMetadata. - - Args: - pooling_metadata: PoolingMetadata instance to convert. - device: Device to store the tensors. - """ - # Convert prompt lengths to tensor - pin_memory = is_pin_memory_available() - - prompt_lens_t = torch.tensor( - pooling_metadata.prompt_lens, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - - return cls(prompt_lens=prompt_lens_t.to(device=device, - non_blocking=True), ) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 603848489d443b0780f9c755cc98138bc401fd59..abda90c292a96afa548f8cb83e012e0687a8794a 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .base import MultiModalPlaceholderMap -from .hasher import MultiModalHashDict, MultiModalHasher +from .hasher import MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict, NestedTensors) + MultiModalKwargsItems, MultiModalPlaceholderDict, + MultiModalUUIDDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -22,11 +23,12 @@ __all__ = [ "ModalityData", "MultiModalDataBuiltins", "MultiModalDataDict", - "MultiModalHashDict", "MultiModalHasher", "MultiModalKwargs", + "MultiModalKwargsItems", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", + "MultiModalUUIDDict", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 8c4136e06f8186ca32f7dc6985babc1b30bebee8..35b743ed21d92c4e94a753f613fbf1585e729b23 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,35 +1,83 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import sys -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TypeVar, Union +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch +from typing_extensions import TypeAlias, override from vllm.logger import init_logger from vllm.utils import GiB_bytes, LRUCache -from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves +from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, + json_reduce_leaves) -from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors +from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, + MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItems, NestedTensors) + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry logger = init_logger(__name__) -@dataclass -class MultiModalCacheItemMetadata: - size: int +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. - @classmethod - def wraps(cls, value: "MultiModalCacheValue"): - return cls(size=MultiModalCache.get_item_size(value)) + Args: + item: The processed tensor data corresponding to a multi-modal item. + prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates MultiModalCacheValue = Union[ - MultiModalKwargs, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + MultiModalKwargsItems, MultiModalKwargsItem, + MultiModalKwargs, Mapping[str, NestedTensors], - MultiModalCacheItemMetadata, ] _V = TypeVar("_V", bound=MultiModalCacheValue) @@ -44,22 +92,26 @@ class MultiModalCache: *, debug: bool = False, ) -> int: - # MultiModalKwargs is not a subclass of dict - if isinstance(leaf, MultiModalKwargs): - return cls.get_item_size(leaf.get_data(), debug=debug) + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size - # MultiModalKwargsItem is not a subclass of dict + # These are not subclasses of dict + if isinstance(leaf, MultiModalKwargsItems): + return cls.get_item_size(leaf.data) # type: ignore if isinstance(leaf, MultiModalKwargsItem): - leaf_data = {k: v.data for k, v in leaf.items()} - return cls.get_item_size(leaf_data, debug=debug) + return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalKwargs): + return cls.get_item_size(leaf.data) # type: ignore + + if isinstance(leaf, MultiModalFieldElem): + return cls.get_item_size(leaf.data) # type: ignore # sys.getsizeof doesn't work for tensors if isinstance(leaf, torch.Tensor): return leaf.nbytes - if isinstance(leaf, MultiModalCacheItemMetadata): - return leaf.size - return sys.getsizeof(leaf) @classmethod @@ -76,11 +128,32 @@ class MultiModalCache: ) if debug: - logger.debug("Calculated size of %s to be %.2f GiB", type(value), - size / GiB_bytes) + leaf_count = json_count_leaves(value) + logger.debug( + "Calculated size of %s to be %.2f GiB (%d leaves)", + type(value), + size / GiB_bytes, + leaf_count, + ) return size + @classmethod + def get_item_complexity(cls, value: MultiModalCacheValue) -> int: + """ + Get the number of leaf elements in a multi-modal cache value. + + This provides a measure of structural complexity that can be useful + for debugging cache performance and understanding data patterns. + + Args: + value: The multi-modal cache value to analyze. + + Returns: + The number of leaf elements in the nested structure. + """ + return json_count_leaves(value) + @classmethod def get_lru_cache( cls, @@ -93,3 +166,342 @@ class MultiModalCache: GiB_bytes * capacity_gb, getsizeof=lambda x: cls.get_item_size(x, debug=debug), ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. + + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. + """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. + """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = \ + Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] + + +MultiModalProcessorCacheOutItem: TypeAlias = \ + tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, + MultiModalProcessorCacheOutItem]): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. + """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = (parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb) + + return supports_ipc_cache + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + return MultiModalProcessorSenderCache(model_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + """Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], + MultiModalKwargsItem]): + """The required interface for caches on P1.""" + + def get_and_update_features( + self, + mm_features: list["MultiModalFeatureSpec"], + ) -> list["MultiModalFeatureSpec"]: + """Update multimodal features with cached encoder outputs.""" + for feature in mm_features: + feature.data = self.get_and_update_item(feature.data, + feature.identifier) + return mm_features + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """Return a `BaseMultiModalReceiverCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + return MultiModalReceiverCache(model_config) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index c9ce1f0be5f885ffe45b396701aef11269fd4231..da019d40a6fe4723c929addfc333c323945f1797 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -3,7 +3,7 @@ import pickle import uuid -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from typing import Union import numpy as np @@ -16,11 +16,6 @@ from vllm.multimodal.image import convert_image_mode logger = init_logger(__name__) -MultiModalHashDict = Mapping[str, list[str]] -""" -A dictionary containing hashes for items in each modality. -""" - class MultiModalHasher: @@ -43,7 +38,25 @@ class MultiModalHasher: return cls.item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): - return cls.item_to_bytes("tensor", obj.numpy()) + tensor_obj: torch.Tensor = obj.cpu() + tensor_dtype = tensor_obj.dtype + tensor_shape = tensor_obj.shape + + # NumPy does not support bfloat16. + # Workaround: View the tensor as a contiguous 1D array of bytes + if tensor_dtype == torch.bfloat16: + tensor_obj = tensor_obj.contiguous() + tensor_obj = tensor_obj.view( + (tensor_obj.numel(), )).view(torch.uint8) + + return cls.item_to_bytes( + "tensor", { + "original_dtype": str(tensor_dtype), + "original_shape": tuple(tensor_shape), + "data": tensor_obj.numpy(), + }) + + return cls.item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index d3f57cf5338d5665a099b8e2b7b2050caa5492c5..f8ea3835f049d5f90d5a4541bf5eca43dbcefcbc 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils.jsontree import JSONTree, json_map_leaves @@ -22,7 +22,8 @@ if TYPE_CHECKING: from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature - from .hasher import MultiModalHashDict + from .processing import MultiModalHashes + else: torch = LazyLoader("torch", globals(), "torch") @@ -115,6 +116,16 @@ The built-in modalities are defined by [`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. """ +MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]] +""" +A dictionary containing user-provided UUIDs for items in each modality. +If a UUID for an item is not provided, its entry will be `None` and +MultiModalHasher will compute a hash for the item. + +The UUID will be used to identify the item for all caching purposes +(input processing caching, embedding caching, prefix caching, etc). +""" + @dataclass(frozen=True) class PlaceholderRange: @@ -198,6 +209,29 @@ A dictionary containing nested tensors which have been batched via """ +@dataclass +class MultiModalFeatureSpec: + """ + Represents a single multimodal input with its processed data and metadata. + + Used by the V1 engine to track multimodal data through processing and + caching. A request containing multiple multimodal items will have one + MultiModalFeatureSpec per item. + """ + + data: Optional["MultiModalKwargsItem"] + """Multimodal data for this feature""" + + modality: str + """Based on the input, e.g., "image", "audio", "video".""" + + identifier: str + """mm_hash or uuid for caching encoder outputs.""" + + mm_position: PlaceholderRange + """e.g., PlaceholderRange(offset=2, length=336)""" + + @dataclass class MultiModalFieldElem: """ @@ -656,7 +690,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None: super().__init__(data) - modalities = {elem.modality for elem in self.data.values()} + modalities = {elem.modality for elem in self.values()} assert len(modalities) == 1, f"Found different modalities={modalities}" self._modality = next(iter(modalities)) @@ -664,20 +698,23 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): def modality(self) -> str: return self._modality - def get_data(self) -> Mapping[str, NestedTensors]: + def get_data(self) -> dict[str, NestedTensors]: return {key: elem.data for key, elem in self.items()} -class MultiModalKwargs: - """ - A dictionary that represents the keyword arguments to - [`torch.nn.Module.forward`][]. +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) - The metadata `items` enables us to obtain the keyword arguments - corresponding to each data item in - [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via - [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and - [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items]. + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): + """ + A dictionary of + [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s + by modality. """ @staticmethod @@ -712,19 +749,74 @@ class MultiModalKwargs: elems = [v[item_idx] for v in elems_in_modality.values()] items.append(MultiModalKwargsItem.from_elems(elems)) - return MultiModalKwargs(items) - - def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None: - super().__init__() + return MultiModalKwargsItems.from_seq(items) + @staticmethod + def from_seq(items: Sequence[MultiModalKwargsItem]): items_by_modality = full_groupby(items, key=lambda x: x.modality) - self._items_by_modality = dict(items_by_modality) + return MultiModalKwargsItems(items_by_modality) - self._data: Optional[Mapping[str, NestedTensors]] = None + def __getitem__(self, modality: str) -> Sequence[_I]: + if modality not in self: + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {set(self.keys())}") - @property - def modalities(self): - return self._items_by_modality.keys() + return super().__getitem__(modality) # type: ignore[return-value] + + def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") + + for key, elem in item.items(): + elems_by_key[key].append(elem) + + return MultiModalKwargs({ + key: + elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + }) + + +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] + + +class MultiModalKwargs(UserDict[str, NestedTensors]): + """ + A dictionary that represents the keyword arguments to + [`torch.nn.Module.forward`][]. + """ + + @staticmethod + @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_hf_inputs` and " + "access the tensor data using `.get_data()`.") + def from_hf_inputs( + hf_inputs: "BatchFeature", + config_by_key: Mapping[str, MultiModalFieldConfig], + ): + return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \ + .get_data() + + @staticmethod + @deprecated("`MultiModalKwargs.from_items` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_seq` and " + "access the tensor data using `.get_data()`.") + def from_items( + items: Sequence[MultiModalKwargsItem], + *, + pin_memory: bool = False, + ): + return MultiModalKwargsItems.from_seq(items) \ + .get_data(pin_memory=pin_memory) @staticmethod def _try_stack(nested_tensors: NestedTensors, @@ -813,92 +905,24 @@ class MultiModalKwargs: return cast(BatchedTensorInputs, json_mapped) - def keys(self): - return self.get_data().keys() - - def values(self): - return self.get_data().values() - - def items(self): - return self.get_data().items() - - def get(self, key: str, /, default=None): - return self.get_data().get(key, default) - - def pop(self, key: str, *args, **kwargs): - data = dict(self.get_data()) - res = data.pop(key, *args, **kwargs) - - for items in self._items_by_modality.values(): - for item in items: - item.pop(key, *args, **kwargs) - - self._data = None - - return res - - def __iter__(self): - return iter(self.get_data()) - def __getitem__(self, key: str): - return self.get_data()[key] + if key not in self: + raise KeyError(f"Keyword argument {key!r} not found. " + f"Available keys: {set(self.keys())}") + + return super().__getitem__(key) def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - return self._items_by_modality == other._items_by_modality - - def _validate_modality(self, method_name: str, modality: str) -> None: - if not self._items_by_modality: - raise RuntimeError( - f"`{method_name}` is not supported when " - "MultiModalKwargs is not initialized with `items`") - - if modality not in self._items_by_modality: - available_modalities = set(self._items_by_modality.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") - - def get_item_count(self, modality: str) -> int: - """Get the number of items belonging to a modality.""" - self._validate_modality("get_item_count", modality) - return len(self._items_by_modality[modality]) + for k in self: + if k not in other: + return False + if not nested_tensors_equal(self[k], other[k]): + return False - def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem: - """ - Get the keyword arguments corresponding to an item identified by - its modality and index. - """ - self._validate_modality("get_item", modality) - return self._items_by_modality[modality][item_index] - - def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: - """ - Get the keyword arguments corresponding to each item belonging to - a modality. - """ - self._validate_modality("get_items", modality) - return self._items_by_modality[modality] - - def get_data(self, - *, - pin_memory: bool = False) -> Mapping[str, NestedTensors]: - if self._data is not None: - return self._data - - elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for items in self._items_by_modality.values(): - for item in items: - for key, elem in item.items(): - elems_by_key[key].append(elem) - - data = { - key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() if len(elems) > 0 - } - self._data = data - return data + return True MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] @@ -923,13 +947,10 @@ class MultiModalInputs(TypedDict): prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - - mm_kwargs: MultiModalKwargs + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: Optional["MultiModalHashDict"] + mm_hashes: "MultiModalHashes" """The hashes of the multi-modal data.""" mm_placeholders: "MultiModalPlaceholderDict" @@ -956,6 +977,3 @@ class MultiModalEncDecInputs(MultiModalInputs): encoder_prompt_token_ids: list[int] """The processed token IDs of the encoder prompt.""" - - encoder_token_type_ids: NotRequired[list[int]] - """The token type IDs of the encoder prompt.""" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 99c341ba3cd740af52a030e86a96c1d243b3bbb8..9a9cb5dc97aa7f150507dcb9abbf105cd5436b79 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargs, VideoItem) + MultiModalFieldConfig, MultiModalKwargsItems, VideoItem) _T = TypeVar("_T") _I = TypeVar("_I") @@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], self.fields_config = fields_config self.required_fields = required_fields - self._kwargs = MultiModalKwargs.from_hf_inputs( + self._kwargs = MultiModalKwargsItems.from_hf_inputs( BatchFeature(dict(data)), fields_config, ) def get_count(self) -> int: - return self._kwargs.get_item_count(self.modality) + return len(self._kwargs[self.modality]) def get(self, index: int) -> Mapping[str, torch.Tensor]: - return { - k: v.data - for k, v in self._kwargs.get_item(self.modality, index).items() - } + return self._kwargs[self.modality][index].get_data() def get_processor_data(self) -> Mapping[str, object]: return {} diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 4684bf6f3d83a8e88b573708d0f1ef4cc580d484..0531b7bd9f0a721994628f46b2894ede3eb24c67 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -20,11 +20,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) from vllm.utils import flatten_2d_lists, full_groupby -from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, PlaceholderRange) + MultiModalFieldConfig, MultiModalInputs, + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalKwargsOptionalItems, MultiModalUUIDDict, + PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -33,6 +34,7 @@ if TYPE_CHECKING: from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -43,10 +45,59 @@ PromptSeq = Union[str, list[int]] """A token sequence (list of token IDs) or text.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + return encode_tokens(tokenizer, + text, + add_special_tokens=add_special_tokens) + + +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + return decode_tokens(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) + + +def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: + if isinstance(seq, str): + return seq + + return _cached_decode(tokenizer, tuple(seq)) + + +def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: + if isinstance(seq, str): + return _cached_encode(tokenizer, seq, add_special_tokens=False) + + return seq + + +class _GetMatchIndex(Protocol): + + def __call__( + self, + tokenizer: AnyTokenizer, + prompt: PromptSeq, + start_idx: int = 0, + ) -> Optional[int]: + ... + + @dataclass class PromptIndex: """Resolves to an index in the prompt.""" - get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + get_match_index: _GetMatchIndex class PromptIndexTargets: @@ -58,7 +109,7 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: 0) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0) @staticmethod def prefix(seq: PromptSeq) -> PromptIndex: @@ -69,7 +120,11 @@ class PromptIndexTargets: def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, + start_idx: int = 0, ) -> Optional[int]: + if start_idx != 0: + return None + prefix = seq if isinstance(prompt, str): @@ -95,14 +150,24 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: len(prompt)) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -PromptTarget = Union[PromptSeq, PromptIndex] +UpdateTarget = Union[PromptSeq, PromptIndex] """ The token sequence or text to update. """ +PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +""" +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. +""" + @dataclass class PromptUpdateDetails(Generic[_S]): @@ -111,7 +176,8 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], + torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -133,11 +199,12 @@ class PromptUpdateDetails(Generic[_S]): embed_text: str, ) -> "PromptUpdateDetails[_S]": - def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_text) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + embed_token_ids = encode_tokens(tokenizer, embed_text) + token_ids = _seq2tokens(tokenizer, full) return torch.isin( - torch.tensor(full.token_ids), + torch.tensor(token_ids), torch.tensor(embed_token_ids), ) @@ -148,10 +215,13 @@ class PromptUpdateDetails(Generic[_S]): seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails( - full=seq, - is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, - ) + + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.tensor(token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -189,7 +259,7 @@ class PromptUpdate(ABC): modality: str """The modality for which the update is made.""" - target: PromptTarget + target: PromptUpdateTarget """The token sequence (or text) to update.""" @property @@ -204,10 +274,35 @@ class PromptUpdate(ABC): """Defines how to update the prompt.""" raise NotImplementedError - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": - return BoundPromptUpdate( - _origin=self, - tokenizer=tokenizer, + def _resolve_target(self, item_idx: int) -> UpdateTarget: + target = self.target + if callable(target): + target = target(item_idx) + + return target + + def _resolve_content(self, item_idx: int) -> PromptUpdateDetails: + content = self.content + if callable(content): + content = content(item_idx) + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return content + + def resolve(self, item_idx: int) -> "ResolvedPromptUpdate": + """ + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output a copy of this object with its lazy attributes resolved. + """ + return ResolvedPromptUpdate( + modality=self.modality, + item_idx=item_idx, + mode=self.mode, + target=self._resolve_target(item_idx), + content=self._resolve_content(item_idx), ) @@ -354,30 +449,6 @@ class PromptReplacement(PromptUpdate): return UpdateMode.REPLACE -@lru_cache(maxsize=2048) -def _cached_encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: Optional[bool] = None, -) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) - - -@lru_cache(maxsize=2048) -def _cached_decode( - tokenizer: AnyTokenizer, - token_ids: tuple[int, ...], - *, - skip_special_tokens: Optional[bool] = None, -) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) - - class _HasModalityAttr(Protocol): modality: str @@ -398,126 +469,103 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: return full_groupby(values, key=lambda x: x.modality) -@dataclass -class _BoundPromptSequence: - """ - A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound - to a tokenizer to automatically - convert between token sequence and text representations. - """ - tokenizer: AnyTokenizer = field(repr=False) - - _text: Optional[str] - _token_ids: Optional[list[int]] +class PromptTargetMatch(NamedTuple): + start_idx: int + end_idx: int - @staticmethod - def from_seq( - tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") +@dataclass(frozen=True) +class ResolvedPromptUpdate: + """ + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its + lazy attributes resolved, apart from those related to tokenization. + """ - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) + modality: str + """The modality for which the update is made.""" - return self._text + item_idx: int + """The index within `modality` of the item this update pertains to.""" - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) + mode: UpdateMode + """Defines how to update the prompt.""" - return self._token_ids + target: UpdateTarget + """The token sequence (or text) to update.""" + content: PromptUpdateDetails = field(repr=False) + """The placeholder tokens that are part of the update.""" -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] + def iter_token_matches( + self, + prompt: list[int], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) -@dataclass -class BoundPromptUpdate: - """ - A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound - to a tokenizer to automatically convert - [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of - [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] - between token sequence and text representations. - """ - _origin: PromptUpdate - tokenizer: AnyTokenizer = field(repr=False) + return - def __post_init__(self) -> None: - self._content_cache = dict[int, _BoundPromptContent]() + target_token_ids = _seq2tokens(tokenizer, target) - @property - def modality(self) -> str: - return self._origin.modality + for match in iter_token_matches(prompt, + target_token_ids, + start_idx=start_idx): + yield PromptTargetMatch(match.start_idx, match.end_idx) - @property - def target(self) -> Union[_BoundPromptSequence, PromptIndex]: - """The token sequence (or text) to update.""" - target = self._origin.target + def iter_text_matches( + self, + prompt: str, + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target if isinstance(target, PromptIndex): - return target + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) - return _BoundPromptSequence.from_seq(self.tokenizer, target) + return - @property - def content(self) -> PromptUpdateContent: - """The placeholder tokens that are part of the update.""" - return self._origin.content + target_text = _seq2text(tokenizer, target) - @property - def mode(self) -> UpdateMode: - """Defines how to update the prompt.""" - return self._origin.mode + for match in re.finditer(re.escape(target_text), prompt, + pos=start_idx): + yield PromptTargetMatch(match.start(), match.end()) - def get_content(self, item_idx: int) -> _BoundPromptContent: - """ - Given the index of the processed item within - [`modality`][vllm.multimodal.processing.PromptUpdate.modality], - output the token sequence (or text) to update. - """ - content = self.content - if callable(content): - cache_key = item_idx - if cache_key in self._content_cache: - return self._content_cache[cache_key] + def iter_matches( + self, + prompt: Union[list[int], str], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + if isinstance(prompt, str): + return self.iter_text_matches(prompt, + tokenizer, + start_idx=start_idx) - content = content(item_idx) - else: - cache_key = None + return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): if not isinstance(content, PromptUpdateDetails): content = PromptUpdateDetails.from_seq(content) - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_content = _BoundPromptContent(full=bound_full, - is_embed=content.is_embed) - - if cache_key is not None: - self._content_cache[cache_key] = bound_content - - return bound_content + return replace(self, content=content) class _TokenMatch(NamedTuple): @@ -528,6 +576,8 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], + *, + start_idx: int = 0, ) -> Generator[_TokenMatch]: """ Yield each occurrence of `match_ids` in `token_ids`. @@ -540,7 +590,6 @@ def iter_token_matches( if match_len == 0: return - start_idx = 0 while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len @@ -580,68 +629,6 @@ def replace_token_matches( return flatten_2d_lists(out_seqs) -@dataclass(repr=False) -class PromptTargetMatch(ABC): - _origin: BoundPromptUpdate - - @property - def modality(self) -> str: - return self._origin.modality - - @property - @abstractmethod - def start_idx(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def end_idx(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") - - -@dataclass(repr=False) -class _PromptTargetIndexMatch(PromptTargetMatch): - match_idx: int - - @property - def start_idx(self) -> int: - return self.match_idx - - @property - def end_idx(self) -> int: - return self.match_idx - - -@dataclass(repr=False) -class _PromptTargetTokenMatch(PromptTargetMatch): - match: _TokenMatch - - @property - def start_idx(self) -> int: - return self.match.start_idx - - @property - def end_idx(self) -> int: - return self.match.end_idx - - -@dataclass(repr=False) -class _PromptTargetTextMatch(PromptTargetMatch): - match: re.Match[str] - - @property - def start_idx(self) -> int: - return self.match.start() - - @property - def end_idx(self) -> int: - return self.match.end() - - @dataclass class PlaceholderFeaturesInfo: modality: str @@ -664,163 +651,161 @@ class PlaceholderFeaturesInfo: ) -def find_token_matches( - prompt: list[int], - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTokenMatch(update, match) - for match in iter_token_matches(prompt, target.token_ids) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] - - -def find_text_matches( - prompt: str, - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTextMatch(update, match) - for match in re.finditer(re.escape(target.text), prompt) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] - - -def _resolve_matches( - prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], -) -> list[PromptTargetMatch]: - """ - Resolve `mm_matches` to ensure that there are no overlapping matches, - and sort them such that earlier matches take priority over later ones. - """ - matches = [m for matches in mm_matches.values() for m in matches] - - seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) - - for match in matches: - for idx in range(match.start_idx, match.end_idx): - if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") +_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] - seen_matches[idx] = match - return sorted(matches, key=lambda x: x.start_idx) +def _find_matches( + prompt: _S, + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, + *, + prev_end_idx: int = 0, + current_result: "MultiModalPromptUpdatesApplyResult", +) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: + mode: Optional[UpdateMode] = None + mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() + + for modality, modality_updates in mm_prompt_updates.items(): + for item_idx, item_updates in enumerate(modality_updates): + if current_result[modality][item_idx] is not None: + continue # Updates have already been applied for this item + + for update_idx, update in enumerate(item_updates): + if (modality, item_idx) in mm_matches: + break # Already found a match for this item + + for match in update.iter_matches( + prompt, + tokenizer, + start_idx=prev_end_idx, + ): + # All matches should share the same mode + if mode is None: + mode = update.mode + elif mode != update.mode: + continue + + mm_matches[(modality, item_idx)] = match, update_idx + break # Get only the first valid match per item + + # Prioritize earlier matches + matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0]) + + # To avoid conflicts, only replace one non-empty item at a time + if mode == UpdateMode.REPLACE: + matches_to_apply_ = list[_MatchToApply]() + has_non_empty_matches = False + + for item in matches_to_apply: + _, (match, _) = item + if match.start_idx == match.end_idx: + matches_to_apply_.append(item) + elif not has_non_empty_matches: + has_non_empty_matches = True + matches_to_apply_.append(item) + + matches_to_apply = matches_to_apply_ + + return mode, matches_to_apply def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[_S]: - """Apply the updates in `mm_matches` to `prompt`.""" + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: + prompt_len = len(prompt) + out_seqs = list[Union[str, list[int]]]() - prev_end_idx = 0 - next_idx_by_modality = defaultdict[str, int](lambda: 0) + out_result: MultiModalPromptUpdatesApplyResult = { + m: [None] * len(items) + for m, items in mm_prompt_updates.items() + } - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + start_idx = prev_end_idx = 0 + while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt + found = False - item_start_idx = next_idx_by_modality[modality] - max_item_count = mm_item_counts.get(modality, 0) - if item_start_idx >= max_item_count: - continue + mode, matches_to_apply = _find_matches( + prompt, + mm_prompt_updates, + tokenizer, + prev_end_idx=prev_end_idx, + current_result=out_result, + ) - start_idx = match.start_idx - end_idx = match.end_idx - origin = match._origin - mode = origin.mode - - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = max_item_count - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = max_item_count if start_idx == end_idx else 1 - else: - assert_never(mode) + if mode is not None: + for (modality, item_idx), (match, update_idx) in matches_to_apply: + found = True - item_end_idx = min(item_start_idx + num_inserts, max_item_count) + matched_update = mm_prompt_updates[modality][item_idx][ + update_idx] + matched_content = matched_update.content.full - for item_idx in range(item_start_idx, item_end_idx): - content = origin.get_content(item_idx) - insert_seq = (content.full.text if isinstance(prompt, str) else - content.full.token_ids) + if mode == UpdateMode.INSERT: + end_idx_to_insert = match.end_idx + elif mode == UpdateMode.REPLACE: + end_idx_to_insert = match.start_idx + else: + assert_never(mode) - out_seqs.append(insert_seq) + out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) + out_seqs.append( + _seq2text(tokenizer, matched_content + ) if isinstance(prompt, str) else _seq2tokens( + tokenizer, matched_content)) + out_result[modality][item_idx] = update_idx - prev_end_idx = end_idx - next_idx_by_modality[modality] += item_end_idx - item_start_idx + # Exclude overlapping matches + start_idx = prev_end_idx = match.end_idx + + if not found: + start_idx += 1 out_seqs.append(prompt[prev_end_idx:]) - return cast(list[_S], out_seqs) + return cast(list[_S], out_seqs), out_result def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[int]: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, + tokenizer) - return flatten_2d_lists(token_id_seqs) + return flatten_2d_lists(token_id_seqs), result def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> str: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - texts = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return "".join(texts) + return "".join(texts), result def _iter_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. @@ -832,6 +817,8 @@ def _iter_placeholders( Note that empty matches are ignored. """ prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = defaultdict[str, int](lambda: 0) start_idx = 0 @@ -843,9 +830,9 @@ def _iter_placeholders( if item_idx >= mm_item_counts.get(modality, 0): continue - for update_info in modality_updates: - content = update_info.get_content(item_idx) - content_tokens_full = content.full.token_ids + for update in modality_updates[item_idx]: + content = update.content + content_tokens_full = _seq2tokens(tokenizer, content.full) content_len_full = len(content_tokens_full) end_idx_full = start_idx + content_len_full @@ -855,7 +842,8 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed(content.full) + content_is_embed = content_is_embed( + tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -879,29 +867,14 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) return dict(full_groupby_modality(it)) -class ProcessingCache(MultiModalCache): - - def __init__(self, capacity_gb: float) -> None: - super().__init__() - - self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) - - self.get = self._cache.get - self.put = self._cache.put - self.reset = self._cache.clear - - -_CacheItemOrHash = Union[MultiModalKwargsItem, str] - - class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -985,8 +958,28 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) MultiModalHashes = dict[str, list[str]] """ A collection of hashes with a similar structure as -[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]] """ +A collection of prompt updates with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +""" +For an item `MultiModalPromptUpdates[k][i]`, +`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the +`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the +`ResolvedPromptUpdate` instances have been applied. +""" + + +class MultiModalProcessingInfo(NamedTuple): + kwargs: MultiModalKwargsOptionalItems + hashes: MultiModalHashes + prompt_updates: MultiModalPromptUpdates class BaseMultiModalProcessor(ABC, Generic[_I]): @@ -996,11 +989,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. """ - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1026,8 +1021,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: str, mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + return self.apply(prompt, + mm_data, + hf_processor_mm_kwargs, + mm_hash_overrides=mm_hash_overrides) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1095,7 +1096,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: """ Given the original multi-modal items for this modality @@ -1113,14 +1114,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ raise NotImplementedError + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + mm_item_counts: Mapping[str, int], + ) -> MultiModalPromptUpdates: + return { + modality: [[update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0))] + for modality, updates in full_groupby_modality(prompt_updates) + } + + def _get_mm_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> MultiModalPromptUpdates: + unbound_prompt_updates = self._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates, + mm_items.get_all_counts(), + ) + + for modality, prompt_updates in mm_prompt_updates.items(): + for item_idx, item_prompt_updates in enumerate(prompt_updates): + if len(item_prompt_updates) > 1: + logger.warning_once( + "Detected %d prompt updates for `mm_items[%r][%s]`. " + "Multiple prompt updates per item is now " + "deprecated and may be removed in v0.13. " + "Instead, please specify dynamic update targets " + "in the same prompt update definition by passing " + "a function to `PromptUpdate.target`.", + len(prompt_updates), + modality, + item_idx, + ) + + return mm_prompt_updates + def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + tokenizer = self.info.get_tokenizer() + + return find_mm_placeholders(new_token_ids, mm_prompt_updates, + tokenizer) def _get_hf_mm_data( self, @@ -1311,76 +1358,154 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return prompt_ids, mm_processed_data, False + def _hash_mm_items( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> MultiModalHashes: + """Create MM hashes to be returned (only used in V1). + + + Note: When overrides are provided via callers of `apply`, + `_hash_mm_items` will be bypassed and the overrides will be used. + """ + model_id = self.info.model_id + + hashes: MultiModalHashes = {} + mm_hash_overrides = mm_hash_overrides or {} + + for modality, items in mm_items.items(): + if modality in mm_hash_overrides: + mm_hashes = mm_hash_overrides[modality] + if isinstance(mm_hashes, str): + mm_hashes = [mm_hashes] + + # For None entries, compute a hash; otherwise, use provided ID. + computed: list[str] = [] + for i, item in enumerate(items): + mm_hash = mm_hashes[i] + + # NOTE: Even if a mm_hash is provided, we still compute a + # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` + # are provided. This is because the processed multimodal + # inputs can be different depending on the processor kwargs. + if mm_hash is None or \ + hf_processor_mm_kwargs or \ + tokenization_kwargs: + + # NOTE: use provided hash string to hash with kwargs + # if available for better performance. + item = mm_hash if mm_hash is not None else item + computed.append( + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs)) + else: + computed.append(mm_hash) + hashes[modality] = computed + else: + hashes[modality] = [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs) + for item in items + ] + + return hashes + def _get_cache_missing_items( self, - cache: ProcessingCache, + cache: "BaseMultiModalProcessorCache", mm_data_items: MultiModalDataItems, mm_hashes: MultiModalHashes, - ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]: - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = { - modality: [(h if (v := cache.get(h)) is None else v) - for h in hashes] + ) -> MultiModalDataItems: + mm_is_cached = { + modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } mm_missing_idxs = { modality: [ - idx for idx, item_or_hash in enumerate(items_or_hashes) - if isinstance(item_or_hash, str) + idx for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached ] - for modality, items_or_hashes in mm_cache_items_or_hashes.items() + for modality, items_is_cached in mm_is_cached.items() } mm_missing_data = { modality: [mm_data_items[modality][idx] for idx in idxs] for modality, idxs in mm_missing_idxs.items() } - return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data) + return self._to_mm_items(mm_missing_data) - def _hash_mm_items( + def _recompute_cached_prompt_update( self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> MultiModalHashes: - """Create MM hashes to be returned (only used in V1).""" - model_id = self.info.model_id - - return { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs, - **tokenization_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. + """ + return replace(cached_update, item_idx=new_item_idx) def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], - mm_missing_kwargs: MultiModalKwargs, - ) -> dict[str, list[MultiModalKwargsItem]]: + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, + mm_missing_kwargs: MultiModalKwargsItems, + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) - for modality, items_or_hashes in mm_cache_items_or_hashes.items(): - for item_or_hash in items_or_hashes: - if isinstance(item_or_hash, str): - kw_item = mm_missing_kwargs.get_item( - modality, - mm_missing_next_idx[modality], - ) - cache.put(item_or_hash, kw_item) + merged_kwargs = defaultdict[str, + list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get( + modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] + mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - kw_item = item_or_hash + item = None - merged_items[modality].append(kw_item) + kwargs, updates = cache.get_and_update_item(item, item_hash) - return dict(merged_items) + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append([ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ]) + + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) + + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1389,8 +1514,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, mm_processed_data, @@ -1403,17 +1529,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): enable_hf_prompt_update=True, ) - mm_kwargs = MultiModalKwargs.from_hf_inputs( + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_processed_data, self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - if return_mm_hashes else None) + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = self._hash_mm_items(mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) + + mm_prompt_updates = self._get_mm_prompt_updates( + mm_data_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) - return prompt_ids, mm_kwargs, mm_hashes, is_update_applied + return prompt_ids, mm_info, is_update_applied def _cached_apply_hf_processor( self, @@ -1422,8 +1562,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, caching the results and reusing cached results. @@ -1437,22 +1578,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) - mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - ( - mm_cache_items_or_hashes, - mm_missing_data_items, - ) = self._get_cache_missing_items( + mm_hashes = self._hash_mm_items(mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, ) - mm_hashes_to_return = mm_hashes if return_mm_hashes else None - # NOTE: `prompt` does not correspond to `mm_missing_data_items`, # so we can't apply prompt updates until the new multimodal # items are combined with the cached multimodal items @@ -1468,66 +1607,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): enable_hf_prompt_update=False, ) - mm_missing_kwargs = MultiModalKwargs.from_hf_inputs( + mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_processed_data, self._get_mm_fields_config(mm_missing_processed_data, hf_processor_mm_kwargs), ) - mm_cache_items_merged = self._merge_mm_kwargs( + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + hf_processor_mm_kwargs, + mm_missing_kwargs, + ) + + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( cache, - mm_cache_items_or_hashes=mm_cache_items_or_hashes, + mm_hashes=mm_hashes, mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, ) - mm_kwargs = MultiModalKwargs([ - item for cache_items in mm_cache_items_merged.values() - for item in cache_items - ]) - - return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied - - def _bind_and_group_updates( - self, - prompt_updates: Sequence[PromptUpdate], - ) -> dict[str, Sequence[BoundPromptUpdate]]: - tokenizer = self.info.get_tokenizer() + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) - it = (update.bind(tokenizer) for update in prompt_updates) - return dict(full_groupby_modality(it)) + return prompt_ids, mm_info, is_update_applied def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - return apply_token_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_token_matches(prompt, mm_prompt_updates, tokenizer) def _apply_text_matches( self, prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> str: - return apply_text_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[str, MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_text_matches(prompt, mm_prompt_updates, tokenizer) def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() - mm_token_matches = { - modality: find_token_matches(token_ids, updates) - for modality, updates in mm_prompt_updates.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -1540,59 +1673,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = self._apply_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values()): + new_text = decode_tokens(tokenizer, new_token_ids) + else: + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, ) - text = decode_tokens(tokenizer, token_ids) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, updates) - for modality, updates in mm_prompt_updates.items() - } - text = self._apply_text_matches( - text, - mm_text_matches, - mm_item_counts, + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } + matched_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + assert update_idx is not None, ( + "Failed to apply prompt replacement for " + f"mm_items[{modality!r}][{item_idx}]") + + matched_updates[modality].append( + [mm_prompt_updates[modality][item_idx][update_idx]]) placeholders = self._find_mm_placeholders( - matched_updates, - token_ids, - mm_item_counts, + new_token_ids, + dict(matched_updates), ) - return token_ids, text, placeholders + return new_token_ids, new_text, placeholders def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): - if modality in mm_kwargs.modalities: - items = mm_kwargs.get_items(modality) - else: - items = [] + items = mm_kwargs.get(modality, []) if len(items) != item_count: raise RuntimeError( @@ -1628,27 +1748,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1662,7 +1773,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1674,7 +1784,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1696,23 +1808,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ( prompt_ids, - mm_kwargs, - mm_hashes, + mm_info, is_update_applied, ) = self._cached_apply_hf_processor( prompt, mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) # NOTE: tokenization_kwargs are not required to init processor prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, prompt_ids=prompt_ids, - mm_kwargs=mm_kwargs, + mm_kwargs=mm_info.kwargs, + mm_prompt_updates=mm_info.prompt_updates, is_update_applied=is_update_applied, ) @@ -1725,8 +1836,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): type="multimodal", prompt=prompt, prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, + mm_kwargs=mm_info.kwargs, + mm_hashes=mm_info.hashes, mm_placeholders=mm_placeholder_ranges, ) @@ -1789,7 +1900,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1804,7 +1917,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return self._get_enc_dec_inputs( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index d876887fc155d6c00bb831706804b79a5bdae8b5..ffc69a2db60a4f82a64b6b47a6f3d3af4d4bda44 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.logger import init_logger from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargs, + MultiModalInputs, MultiModalKwargsOptionalItems, MultiModalPlaceholderDict) from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, EncDecMultiModalProcessor) @@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): """Dummy data used for profiling.""" prompt_token_ids: list[int] - multi_modal_data: MultiModalKwargs + multi_modal_data: MultiModalKwargsOptionalItems multi_modal_placeholders: MultiModalPlaceholderDict @@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ded56cca8099984f878560a7d0439657e435867d..38adbf8f3536a5a1c2b44cf37235c45a5c2c3e9f 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from functools import lru_cache from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn @@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .cache import (BaseMultiModalProcessorCache, + processor_only_cache_from_config) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]): info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[_I]: ... @@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]): self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) -# Make sure a different cache is used for each model config -# NOTE: ModelConfig is not hashable so it cannot be passed directly -@lru_cache(maxsize=1) -def _get_processor_cache(model_id: str, capacity_gb: int): - return ProcessingCache(capacity_gb) if capacity_gb > 0 else None - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. @@ -103,31 +96,6 @@ class MultiModalRegistry: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _get_processor_cache(self, model_config: "ModelConfig"): - model_id = model_config.model - capacity_gb = model_config.mm_processor_cache_gb - return _get_processor_cache(model_id, capacity_gb) - - def reset_processor_cache(self, model_config: "ModelConfig") -> bool: - """Reset the multi-modal processing cache.""" - if processor_cache := self._get_processor_cache(model_config): - processor_cache.reset() - - return True # Success - - def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool: - """Whether the multi-modal input cache should be enabled. - NOTE: This is put under MultiModalRegistry on purpose to respect - text-only mode for multimodal models. - """ - - if not self.supports_multimodal_inputs(model_config): - return False - - mm_config = model_config.get_multimodal_config() - - return mm_config.mm_processor_cache_gb > 0 - def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. @@ -157,6 +125,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -165,11 +135,11 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, @@ -182,6 +152,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -192,15 +164,19 @@ class MultiModalRegistry: This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } + # TODO: Remove once V0 is gone def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -209,14 +185,19 @@ class MultiModalRegistry: Get the maximum number of tokens from each modality for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + cache = processor_only_cache_from_config(model_config, self) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() } + # TODO: Remove once V0 is gone def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens @@ -227,6 +208,8 @@ class MultiModalRegistry: def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -235,7 +218,7 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -303,7 +286,7 @@ class MultiModalRegistry: model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -311,15 +294,10 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if disable_cache is None: - disable_cache = not model_config.enable_mm_processor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = self._create_processing_ctx(model_config, tokenizer) - cache = None if disable_cache else self._get_processor_cache( - model_config) return factories.build_processor(ctx, cache=cache) @@ -328,13 +306,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -352,13 +332,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) @@ -372,3 +354,22 @@ class MultiModalRegistry: ) return dummy_data + + def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: + """ + Get the maximum length of the encoder input for encoder-decoder models. + """ + if not model_config.is_encoder_decoder: + return 0 + max_tokens = self.\ + get_max_tokens_per_item_by_nonzero_modality(model_config) + if not max_tokens: + # TODO - this function assumes encoder-decoder models are + # multimodal. This will need to change when adding support for more + # than whisper. + return 0 + assert len(max_tokens) == 1, "Encoder-decoder models are expected \ + to implement the multimodal interface with at most one modality." + + first_modality = next(iter(max_tokens)) + return max_tokens[first_modality] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index a80f09bb19272796e3bdcf80494167640a8bcec3..834b2189e4bed8f44d679996664ce0f4fa28ec4b 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,6 +3,8 @@ import asyncio import atexit +import itertools +import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby @@ -32,11 +34,13 @@ _M = TypeVar("_M") if TYPE_CHECKING: from .inputs import (BatchedTensorInputs, MultiModalKwargs, - MultiModalKwargsItem, MultiModalPlaceholderDict) + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalPlaceholderDict) else: BatchedTensorInputs = Any MultiModalKwargs = Any MultiModalKwargsItem = Any + MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any global_thread_pool = ThreadPoolExecutor( @@ -359,18 +363,20 @@ def argsort_mm_positions( "`group_mm_kwargs_by_modality` and will be removed in v0.13. " "Please use `group_mm_kwargs_by_modality` instead.") def group_mm_inputs_by_modality( - mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: + mm_inputs: list[MultiModalKwargsItems] +) -> list[list[MultiModalKwargsItems]]: if not mm_inputs: return [] - def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: + def modality_group_func( + mm_input: MultiModalKwargsItems) -> Union[str, int]: # If the input has multiple modalities, return a id as the unique key # for the mm_input input. - if len(mm_input.modalities) > 1: + if len(mm_input) > 1: return id(mm_input) - elif len(mm_input.modalities) == 1: - return list(mm_input.modalities)[0] + elif len(mm_input) == 1: + return next(iter(mm_input.keys())) # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, # this is used to make InternVL with legacy pipeline still work with v1. @@ -397,12 +403,12 @@ def group_mm_kwargs_by_modality( Yields: A tuple `(modality, num_items, grouped_kwargs)`. """ - from vllm.multimodal.inputs import MultiModalKwargs + from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): items_lst = list(items) - # mm_kwargs_group = MultiModalKwargs(items_lst) \ + # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \ # .get_data(pin_memory=pin_memory) # if device is not None: @@ -417,7 +423,10 @@ def group_mm_kwargs_by_modality( # We will also need to update each model to remove `flatten_bn`. mm_kwargs_group = MultiModalKwargs.as_kwargs( MultiModalKwargs.batch( - [MultiModalKwargs([item]) for item in items_lst], + [ + MultiModalKwargsItems.from_seq([item]).get_data() + for item in items_lst + ], pin_memory=pin_memory, ), device=device, @@ -452,12 +461,227 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, num_chunks_per_rank, ...] vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0) vision_embeddings = vision_embeddings[:num_chunks, ...] return vision_embeddings +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. + + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings + + def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None, diff --git a/vllm/outputs.py b/vllm/outputs.py index 9784a8894472f04b0d0d0e1f70aaa43c895e778f..acdb2f89ce7354bc369a5248da6c8e685b809a14 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -409,7 +409,7 @@ class EmbeddingOutput: Args: embedding: The embedding vector, which is a list of floats. - Its length depends on the hidden dimension of the model. + Its length depends on the hidden dimension of the model. """ embedding: list[float] @@ -447,7 +447,7 @@ class ClassificationOutput: Args: probs: The probability vector, which is a list of floats. - Its length depends on the number of classes. + Its length depends on the number of classes. """ probs: list[float] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 0b16a8e1d1d8b85ab49ba691b87f05fee95c5f2e..12d5e0bf08652191115b9f22721247af12e217df 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -69,6 +69,7 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" dist_backend: str = "gloo" + device_control_env_var = "CPU_VISIBLE_MEMORY_NODES" @property def supported_dtypes(self) -> list[torch.dtype]: @@ -268,7 +269,7 @@ class CpuPlatform(Platform): DEFAULT_MAX_NUM_BATCHED_TOKENS) @classmethod - def get_allowed_cpu_memory_node_list( + def get_allowed_cpu_core_node_list( cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" @@ -297,6 +298,13 @@ class CpuPlatform(Platform): allowed_numa_nodes.add(x.numa_node) # type: ignore allowed_numa_nodes_list = sorted(allowed_numa_nodes) + env_key = CpuPlatform.device_control_env_var + if (env_key in os.environ and os.environ[env_key] != ""): + visible_nodes = [int(s) for s in os.environ[env_key].split(',')] + allowed_numa_nodes_list = [ + x for x in visible_nodes if x in allowed_cpu_id_list + ] + return allowed_numa_nodes_list, logical_cpu_list @classmethod @@ -332,5 +340,10 @@ class CpuPlatform(Platform): supplied model configuration. """ arch = cls.get_cpu_architecture() - return (cls.supports_v1(model_config) and arch - in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM)) + return (cls.supports_v1(model_config) + and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, + CpuArchEnum.ARM, CpuArchEnum.S390X)) + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 321db8287c0f82fda27fe240d63aebb869b1cf94..5cbb7346436ef931694680f09dc264e67f3483b1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -350,17 +350,7 @@ class CudaPlatformBase(Platform): return FLEX_ATTENTION_V1 # Backends for V0 engine - if selected_backend == _Backend.FLASHINFER: - logger.info("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) - logger.info_once( - "Using HND KV cache layout on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") - set_kv_cache_layout("HND") - return "vllm.attention.backends.flashinfer.FlashInferBackend" - elif selected_backend == _Backend.XFORMERS: + if selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") return "vllm.attention.backends.xformers.XFormersBackend" elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: @@ -416,10 +406,6 @@ class CudaPlatformBase(Platform): if (fp8_kv_cache and not flash_attn_supports_fp8()): logger.info( "Cannot use FlashAttention backend for FP8 KV cache.") - logger.warning( - "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by setting environment variable " - "VLLM_ATTENTION_BACKEND=FLASHINFER") target_backend = _Backend.XFORMERS except ImportError: logger.info( @@ -456,6 +442,10 @@ class CudaPlatformBase(Platform): def use_custom_allreduce(cls) -> bool: return True + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @@ -495,18 +485,63 @@ class CudaPlatformBase(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") - will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + attention_backend = envs.VLLM_ATTENTION_BACKEND + supported = False - if cls.is_device_capability(100): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() + if model_config is not None and model_config.use_mla: + # Default to CutlassMLA for blackwell, + # FlashMLA otherwise + if attention_backend is None: + if cls.is_device_capability(100): + attention_backend = "CUTLASS_MLA" + else: + attention_backend = "FLASHMLA" + + # Only FlashMLA supports fp8 + if attention_backend == "FLASHMLA": + supported = True + else: + supported = (not fp8_attention) + else: + # Default to FlashAttention + if attention_backend is None: + attention_backend = "FLASH_ATTN_VLLM_V1" + + # All Blackwell backends support fp8 + if cls.is_device_capability(100): + supported = True + elif attention_backend == "FLASH_ATTN_VLLM_V1": + if fp8_attention: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + else: + supported = True return supported + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4017f1ca7eecbe8bf56c85b30a86af7701545e0b..ad12f7f788cf8ebaee09f54d6ed6aef88a62ff76 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -81,6 +81,7 @@ class CpuArchEnum(enum.Enum): X86 = enum.auto() ARM = enum.auto() POWERPC = enum.auto() + S390X = enum.auto() OTHER = enum.auto() UNKNOWN = enum.auto() @@ -377,6 +378,8 @@ class Platform: return CpuArchEnum.ARM elif machine.startswith("ppc"): return CpuArchEnum.POWERPC + elif machine == "s390x": + return CpuArchEnum.S390X return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN @@ -506,6 +509,14 @@ class Platform: """ return False + @classmethod + def opaque_attention_op(cls) -> bool: + """ + Returns True if we register attention as one giant opaque custom op + on the current platform + """ + return False + @classmethod def validate_request( cls, @@ -526,7 +537,7 @@ class Platform: def get_global_graph_pool(self) -> Any: """ - Return the global graph pool for the this platform. + Return the global graph pool for this platform. """ cls = self.__class__ if cls._global_graph_pool is None: @@ -562,12 +573,20 @@ class Platform: raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ return False + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + """ + Check if the dtype is supported by the current platform. + """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f9825cab30c7e95c713aac54cdf1b0a712e4af7e..78d383b5952bf26751b2b8be4d811aa0fcfe2653 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -180,9 +180,13 @@ class RocmPlatform(Platform): # rocm shares the same device control env var as CUDA device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + # supported_quantization: list[str] = [ + # "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", + # "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "moe_wna16", "slimquant_w4a8","w8a8_int8","awq_marlin","slimquant_w4a8_marlin" + # ] supported_quantization: list[str] = [ - "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4", "moe_wna16", "slimquant_w4a8","w8a8_int8","awq_marlin","slimquant_w4a8_marlin" + "awq", "gptq", "compressed-tensors", "gguf", + "quark", "moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin", "slimquant_w4a8_marlin" ] @classmethod @@ -467,6 +471,10 @@ class RocmPlatform(Platform): supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( @@ -515,5 +523,26 @@ class RocmPlatform(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 7d618256485af05d2266e8ac0ac324174afd49ae..02d795c99c445292c2bac80c227a793c8d910f1f 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -25,6 +25,8 @@ else: logger = init_logger(__name__) +USE_TPU_COMMONS = False + class TpuPlatform(Platform): _enum = PlatformEnum.TPU @@ -195,13 +197,15 @@ class TpuPlatform(Platform): raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True try: from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform TpuPlatform = TpuCommonsPlatform # type: ignore + USE_TPU_COMMONS = True except ImportError: logger.info("tpu_commons not found, using vLLM's TpuPlatform") pass diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index af24437f649f4c38713bcc1847c5409b5d1b0928..84f4cd725646504d691ed0ca9136fab7e4e2b3d4 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -90,28 +90,14 @@ class XPUPlatform(Platform): if cache_config and cache_config.block_size is None: cache_config.block_size = 64 - # FIXME: Temporarily forcing eager mode - # remove after t.compile support stabilizes. - if (envs.VLLM_USE_V1 and model_config is not None - and not vllm_config.model_config.enforce_eager): - from vllm.config import CompilationLevel - vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 - - # Instances created using VllmConfig() typically have model_config as - # None by default. The modification involves adding a check to prevent - # potential null exceptions check and update model config. - if model_config is not None and model_config.dtype == torch.bfloat16 \ - and not cls.device_support_bf16(): - model_config.dtype = torch.float16 - # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config if compilation_config.cudagraph_mode is None or \ compilation_config.cudagraph_mode.max_cudagraph_mode() \ != CUDAGraphMode.NONE: - logger.info("[XPU] CUDA graph is not supported on XPU, " - "disabling cudagraphs.") + logger.info("[XPU] CUDA graph is not supported on XPU, disabling " + "cudagraphs. Fallback to cudagraph_mode=NONE") compilation_config.cudagraph_mode = CUDAGraphMode.NONE # check and update parallel config @@ -162,30 +148,11 @@ class XPUPlatform(Platform): torch.xpu.reset_peak_memory_stats(device) return torch.xpu.max_memory_allocated(device) - @classmethod - def device_support_bf16(cls) -> bool: - device_name = cls.get_device_name().lower() - if cls.is_client_gpu_a770(): - logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," - " fallback to float16") - return False - else: - logger.info( - "Device name %s supports bfloat16. Please file an issue " - "if you encounter any accuracy problems with bfloat16.", - device_name) - return True - @classmethod def is_data_center_gpu(cls) -> bool: device_name = cls.get_device_name().lower() return device_name.count("data center gpu") > 0 - @classmethod - def is_client_gpu_a770(cls) -> bool: - device_name = cls.get_device_name().lower() - return device_name.count("a770") > 0 - @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa @@ -197,3 +164,18 @@ class XPUPlatform(Platform): @classmethod def device_count(cls) -> int: return torch.xpu.device_count() + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + device_name = cls.get_device_name().lower() + # client gpu a770 + if device_name.count("a770") > 0: + raise ValueError( + "Intel Arc A770 have bfloat16 accuracy known issue. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c4f6f8d97c35c681548b39a1b03dea9785c062 --- /dev/null +++ b/vllm/plugins/io_processors/__init__.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +from typing import Optional + +from vllm.config import VllmConfig +from vllm.plugins import load_plugins_by_group +from vllm.plugins.io_processors.interface import IOProcessor +from vllm.utils import resolve_obj_by_qualname + +logger = logging.getLogger(__name__) + + +def get_io_processor( + vllm_config: VllmConfig, + plugin_from_init: Optional[str] = None) -> IOProcessor | None: + # Input.Output processors are loaded as plugins under the + # 'vllm.io_processor_plugins' group. Similar to platform + # plugins, these plugins register a function that returns the class + # name for the processor to install. + + if plugin_from_init: + model_plugin = plugin_from_init + else: + # A plugin can be specified via the model config + # Retrieve the model specific plugin if available + # This is using a custom field in the hf_config for the model + hf_config = vllm_config.model_config.hf_config.to_dict() + config_plugin = hf_config.get("io_processor_plugin") + model_plugin = config_plugin + + if model_plugin is None: + logger.info("No IOProcessor plugins requested by the model") + return None + + logger.debug("IOProcessor plugin to be loaded %s", model_plugin) + + # Load all installed plugin in the group + multimodal_data_processor_plugins = \ + load_plugins_by_group('vllm.io_processor_plugins') + + loadable_plugins = {} + for name, func in multimodal_data_processor_plugins.items(): + try: + assert callable(func) + processor_cls_qualname = func() + if processor_cls_qualname is not None: + loadable_plugins[name] = processor_cls_qualname + except Exception: + logger.warning("Failed to load plugin %s.", name, exc_info=True) + + num_available_plugins = len(loadable_plugins.keys()) + if num_available_plugins == 0: + raise ValueError("No IOProcessor plugins installed" + f" but one is required ({model_plugin}).") + + if model_plugin not in loadable_plugins: + raise ValueError( + f"The model requires the '{model_plugin}' IO Processor plugin " + "but it is not installed. " + f"Available plugins: {list(loadable_plugins.keys())}") + + activated_plugin_cls = loadable_plugins[model_plugin] + + return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config) diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..5c73188d5df51cf81c6599608817d96a5e00830b --- /dev/null +++ b/vllm/plugins/io_processors/interface.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Generic, Optional, TypeVar, Union + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.inputs.data import PromptType +from vllm.outputs import PoolingRequestOutput + +IOProcessorInput = TypeVar('IOProcessorInput') +IOProcessorOutput = TypeVar('IOProcessorOutput') + + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process(self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + collected_output = [item async for i, item in model_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + raise NotImplementedError \ No newline at end of file diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 29f037b4372cdebc887bd98970822e2cbb00e254..6672392b8d08084185e2b536b6ac3b98131771de 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Annotated, Any, Optional import msgspec @@ -27,6 +27,11 @@ class PoolingParams( the classification outputs. softmax: Whether to apply softmax to the reward outputs. """ + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=-1)]] = None + """If set to -1, will use the truncation size supported by the model. If + set to an integer k, will use only the last k tokens from the prompt + (i.e., left truncation). If set to `None`, truncation is disabled.""" ## for embeddings models dimensions: Optional[int] = None diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index b2452b95c1c67c58e3390e3e94372ca8890d76ca..9deec8a1e8fb06b74a0fe605ae4aa5c784e363ba 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -30,7 +30,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): Key Features: - For non-stream output , Recognizes and extracts reasoning ("think") and answer ("answer") sections from text using regular expressions. - - For stream process, it require a token id sequences to change the + - For stream process, it requires a token id sequences to change the reasoning state and other state so it maintains internal state to manage parsing across multiple token. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index df4cca9ba1147b238bffad7908e78ca053cc7f20..c7b4ba34c602e964d5c8139e8a3ee1079eaab428 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -182,7 +182,8 @@ class SamplingParams( optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False """Whether to include the stop strings in output text.""" - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=-1)]] = None """If set to -1, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled.""" @@ -241,7 +242,8 @@ class SamplingParams( spaces_between_special_tokens: bool = True, logits_processors: Optional[list[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=1)]] = None, + msgspec.Meta( + ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, @@ -411,9 +413,11 @@ class SamplingParams( raise ValueError(f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None - and self.truncate_prompt_tokens < 1): - raise ValueError(f"truncate_prompt_tokens must be >= 1, " - f"got {self.truncate_prompt_tokens}") + and (self.truncate_prompt_tokens == 0 + or self.truncate_prompt_tokens < -1)): + raise ValueError( + f"truncate_prompt_tokens must be an integer >= 1 or -1, " + f"got {self.truncate_prompt_tokens}") assert isinstance(self.stop_token_ids, list) if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): raise ValueError(f"stop_token_ids must contain only integers, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 9e4ea4981be7857b970d370af7b90fd6a283ff75..2afe5b8717e1877b441f2d0bd1357a8398879178 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -17,14 +17,17 @@ import torch from vllm import envs from vllm.inputs import SingletonInputs -from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import RequestOutputKind, SamplingParams if TYPE_CHECKING: + from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) +else: + LoRARequest = Any + KVConnectorOutput = Any VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -145,18 +148,7 @@ class SequenceDataDelta( class SequenceData(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ + """Data associated with a sequence.""" # NOTE: we cannot use Union[list, array] because msgspec cannot support # union of 2 list types. _prompt_token_ids: array @@ -256,10 +248,12 @@ class SequenceData(msgspec.Struct, @property def cumulative_logprob(self) -> float: + """The cumulative log probability of the output.""" return self._cumulative_logprob @property def prompt_token_ids(self) -> tuple[int, ...]: + """The token IDs of the prompt.""" return self._prompt_token_ids_tuple @prompt_token_ids.setter @@ -277,6 +271,7 @@ class SequenceData(msgspec.Struct, @property def output_token_ids(self) -> tuple[int, ...]: + """The token IDs of the output.""" return tuple(self._output_token_ids) @output_token_ids.setter @@ -522,16 +517,10 @@ class Sequence: return [0] * len(self.inputs["prompt_embeds"]) return self.inputs["prompt_token_ids"] - @property - def token_type_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [] - return self.inputs.get("token_type_ids", []) - @property def multi_modal_data(self) -> MultiModalKwargs: if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"] + return self.inputs["mm_kwargs"].get_data() return MultiModalKwargs() @@ -779,10 +768,6 @@ class SequenceGroup: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) - @property - def token_type_ids(self) -> Optional[list[int]]: - return self.first_seq.token_type_ids - @property def multi_modal_data(self) -> MultiModalKwargs: if self.first_seq.multi_modal_data: @@ -948,7 +933,7 @@ class SequenceGroupMetadata( omit_defaults=True): # type: ignore[call-arg] """Metadata for a sequence group. Used to create `AttentionMetadata`. - Args: + Attributes: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. seq_data: The sequence data. (Seq id -> sequence data) @@ -958,14 +943,14 @@ class SequenceGroupMetadata( do_sample: True if sampling is required. Sampling is not required when e.g., prefill is chunked, and the current iteration only computes query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. + pooling_params: Pooling parameters. lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. state: Internal state tied to this sequence group. + token_type_ids: Token type IDs. multi_modal_data: Multi modal data. - mm_processor_kwargs: Multimodal input processor / mapper overrides. + multi_modal_placeholders: Multi modal placeholders. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder @@ -988,7 +973,6 @@ class SequenceGroupMetadata( computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) - token_type_ids: Optional[list[int]] = None multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None encoder_seq_data: Optional[SequenceData] = None @@ -1051,12 +1035,13 @@ class SequenceOutput( array_like=True): # type: ignore[call-arg] """The model output associated with a sequence. - Args: + Attributes: parent_seq_id: The ID of the parent sequence (for forking in beam search). output_token: The output token ID. logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) + output_embed: Optional output embedding tensor. """ parent_seq_id: int output_token: int @@ -1149,7 +1134,7 @@ class IntermediateTensors: """ tensors: dict[str, torch.Tensor] - kv_connector_output: Optional["KVConnectorOutput"] + kv_connector_output: Optional[KVConnectorOutput] def __init__(self, tensors): # manually define this function, so that @@ -1174,7 +1159,13 @@ class IntermediateTensors: return len(self.tensors) def __eq__(self, other: object): - return isinstance(other, self.__class__) and self + if not isinstance(other, self.__class__): + return False + if self.tensors.keys() != other.tensors.keys(): + return False + return all( + torch.equal(self.tensors[k], other.tensors[k]) + for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999d4771351e6061e408c61ffb94c02ee..d09c5fa924fb05237731e15033cba478249a5252 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback( return CHAT_TEMPLATES_DIR / "template_basic.jinja" +def _get_minicpmv_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", @@ -27,6 +37,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 0000000000000000000000000000000000000000..661ebd1cf5c17d2828c8a3eed86d1764d5fe2143 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '</think>' in message.content %} + {%- set content = message.content.split('</think>')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n</tool_call>' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '<think>\n\n</think>\n\n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '<think>\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index d8c964fb2a4a4bbe220439c7e0f5e25a37247db0..bec792465bfbbfffe2915f54f8bf3daebe50829c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -14,7 +14,7 @@ from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, + LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig @@ -27,19 +27,6 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, - EAGLEConfig, JAISConfig, - KimiVLConfig, MedusaConfig, - MLPSpeculatorConfig, - Nemotron_Nano_VL_Config, - NemotronConfig, OvisConfig, - RWConfig, SpeculatorsConfig, - Step3TextConfig, Step3VLConfig, - UltravoxConfig) -# yapf: enable -from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -67,24 +54,31 @@ def _get_hf_token() -> Optional[str]: return None -_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { - "chatglm": ChatGLMConfig, - "deepseek_vl_v2": DeepseekVLV2Config, - "kimi_vl": KimiVLConfig, - "Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config, - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig, - "mlp_speculator": MLPSpeculatorConfig, - "medusa": MedusaConfig, - "eagle": EAGLEConfig, - "speculators": SpeculatorsConfig, - "nemotron": NemotronConfig, - "ovis": OvisConfig, - "ultravox": UltravoxConfig, - "step3_vl": Step3VLConfig, - "step3_text": Step3TextConfig, -} +class LazyConfigDict(dict): + + def __getitem__(self, key): + import vllm.transformers_utils.configs as configs + return getattr(configs, super().__getitem__(key)) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + chatglm="ChatGLMConfig", + deepseek_vl_v2="DeepseekVLV2Config", + kimi_vl="KimiVLConfig", + Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", + RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) + RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) + jais="JAISConfig", + mlp_speculator="MLPSpeculatorConfig", + medusa="MedusaConfig", + eagle="EAGLEConfig", + speculators="SpeculatorsConfig", + nemotron="NemotronConfig", + ovis="OvisConfig", + ultravox="UltravoxConfig", + step3_vl="Step3VLConfig", + step3_text="Step3TextConfig", +) _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", @@ -335,6 +329,7 @@ def maybe_override_with_speculators_target_model( gguf_model_repo = Path(model).parent else: gguf_model_repo = None + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model if gguf_model_repo is None else gguf_model_repo, revision=revision, @@ -400,6 +395,7 @@ def get_config( raise ValueError(error_message) from e if config_format == ConfigFormat.HF: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model, revision=revision, @@ -459,6 +455,8 @@ def get_config( model, revision, **kwargs) config_dict["max_position_embeddings"] = max_position_embeddings + from vllm.transformers_utils.configs.mistral import adapt_config_dict + config = adapt_config_dict(config_dict) # Mistral configs may define sliding_window as list[int]. Convert it @@ -503,6 +501,24 @@ def get_config( if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0", ): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" + logger.info_once( + ("Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0."), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.warning_once( + ("Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled."), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) @@ -532,7 +548,7 @@ def try_get_local_file(model: Union[str, Path], revision=revision) if isinstance(cached_filepath, str): return Path(cached_filepath) - except HFValidationError: + except ValueError: ... return None @@ -908,3 +924,42 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: exc_info=e) return max_position_embeddings + + +def get_model_path(model: Union[str, Path], revision: Optional[str] = None): + if os.path.exists(model): + return model + assert huggingface_hub.constants.HF_HUB_OFFLINE + common_kwargs = { + "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE, + "revision": revision, + } + + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) + + from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main') -> Optional[bytes]: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download(model, + file_name, + revision=revision, + token=_get_hf_token()) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, 'rb') as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index bc249c5836034933aad0fbc8bda90fca3294927e..6aabf9e5262e68249a8249a6d92c833b7dd97fba 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -61,8 +61,8 @@ class EAGLEConfig(PretrainedConfig): else f"Eagle3{arch}" for arch in self.model.architectures ] else: - raise ValueError(f"Invalid method {method}. \ - Supported methods are eagle and eagle3.") + raise ValueError(f"Invalid method {method}. " + "Supported methods are eagle and eagle3.") super().__init__(**kwargs) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 05ec8a1b98f60a70cfa794fe7a58ba079feac197..88ea8a315e7fa03cf77139f696bc9a77002a2684 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -24,33 +24,38 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. + # Performance improvements: avoid repeated attribute and function lookups; + # localize frequently used objects; + sub_texts: list[str] = [] current_sub_text: list[str] = [] + convert_tokens_to_string = tokenizer.convert_tokens_to_string + added_vocab_set = set(tokenizer.get_added_vocab()) if mode != "cpm": - all_special_tokens = set(tokenizer.all_special_tokens) + all_special_tokens = set( + tokenizer.all_special_tokens) if skip_special_tokens else () else: all_special_tokens = tokenizer._special_token_set + for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: + # Use precomputed set for skip-special check + if token in all_special_tokens: continue - if token in tokenizer.get_added_vocab(): + if token in added_vocab_set: if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] + sub_texts.append(convert_tokens_to_string(current_sub_text)) + current_sub_text.clear() sub_texts.append(token) else: current_sub_text.append(token) if current_sub_text: if mode != "cpm": - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(convert_tokens_to_string(current_sub_text)) else: - sub_text = tokenizer.decode(current_sub_text) - sub_texts.append(sub_text) + sub_texts = tokenizer.decode(current_sub_text) if spaces_between_special_tokens: return " ".join(sub_texts) - else: - return "".join(sub_texts) + return "".join(sub_texts) # 5 is an arbitrary value that should work for all diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index eca4d7c884dd31d47107ce7f210add22964edb39..8a1ad226d99f0ba33a0d8b90140fc1c43833928d 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -11,5 +11,6 @@ reasons: from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) from vllm.transformers_utils.processors.ovis import OvisProcessor +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor -__all__ = ["DeepseekVLV2Processor", "OvisProcessor"] +__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"] diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py new file mode 100644 index 0000000000000000000000000000000000000000..d3273257ff8c20d082860d64d5983bfaafc03f9d --- /dev/null +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from functools import cached_property +from typing import Optional, Union + +import numpy as np +import PIL +import torch +from transformers import AutoProcessor, BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, + Unpack) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +__all__ = ['Ovis2_5Processor'] +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +MIN_PIXELS = 448 * 448 +MAX_PIXELS = 1792 * 1792 + + +class Ovis2_5ProcessorKwargs(ProcessingKwargs, + total=False): # type: ignore[call-arg] + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + }, + "videos_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + } + } + + +class Ovis2_5Processor(ProcessorMixin): + r""" + Constructs a Ovis processor which wraps a Ovis image processor + and a Qwen2 tokenizer into a single processor. + [`OvisProcessor`] offers all the functionalities of + [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. + See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] + for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will + be used to convert lists of messages in a chat into + a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "image_pad_token"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + patch_size=16, + hidden_stride=2, + temporal_patch_size=1, + **kwargs, + ): + self.image_token = IMAGE_TOKEN + self.video_token = VIDEO_TOKEN + self.image_pad_token = "<|image_pad|>" + + self.patch_size = patch_size + self.hidden_stride = hidden_stride + self.temporal_patch_size = temporal_patch_size + super().__init__(image_processor, + tokenizer, + chat_template=chat_template) + + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { + "image_token": -200, + "video_token": -201, + "visual_atom": -300, + "image_start": -301, + "image_end": -302, + "video_start": -303, + "video_end": -304, + 'image_pad': image_pad_token_id, + } + return extra_special_tokens + + def __call__( + self, + images: ImageInput = None, + videos: Union[np.ndarray, list[ImageInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], + list[PreTokenizedInput]] = None, + **kwargs: Unpack[Ovis2_5ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) + and image(s). This method forwards the `text`and `kwargs` arguments + to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` + is not `None` to encode the text. To prepare the vision inputs, + this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] + if `vision_infos` is not `None`. + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, + `list[PIL.Image.Image]`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The image or batch of images to be prepared. + Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats + are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. + Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as + list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with + a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video + can be a 4D NumPy array or PyTorch tensor, or a nested + list of 3D frames. Both channels-first and channels-last + formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- list of token ids to be fed to a model. + Returned when `text` is not `None`. + - **attention_mask** -- list of indices specifying which tokens + should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* + is in `self.model_input_names` and if `text` is not `None`). + - **pixel_values** -- Pixel values to be fed to a model. + Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to + a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- list of image 3D grid in LLM. Returned + when `images` is not `None`. + - **video_grid_thw** -- list of video 3D grid in LLM. Returned + when `videos` is not `None`. + - **second_per_grid_ts** -- list of video seconds per time grid. + Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Ovis2_5ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # Process all images first + visual_features = {} + output = BatchFeature() + if images is not None: + processed_images = [] + image_placeholders_list = [] + grids = [] + # Process each image + for image in images if isinstance(images, list) else [images]: + pixel_values, image_placeholders, grid = ( + self.preprocess_multidata( + images=image, **output_kwargs["images_kwargs"])) + processed_images.append(pixel_values) + image_placeholders_list.append(image_placeholders) + grids.append(grid) + + # assign all processed images + if processed_images: + visual_features["image_placeholders"] = image_placeholders_list + output["pixel_values"] = processed_images + output["grids"] = grids + + if videos is not None: + processed_videos = [] + videos_placeholders_list = [] + grids = [] + # Process each video + for video in videos if isinstance(videos, list) else [videos]: + pixel_values, video_placeholders, grid = ( + self.preprocess_multidata( + video=video, **output_kwargs["videos_kwargs"])) + processed_videos.append(pixel_values) + videos_placeholders_list.append(video_placeholders) + grids.append(grid) + # assign all processed videos + if processed_videos: + visual_features[ + "video_placeholders"] = videos_placeholders_list + output["video_pixel_values"] = processed_videos + output["video_grids"] = grids + + # Process text input + if text is not None: + if not isinstance(text, list): + text = [text] + tokenized_batched_text = self._tokenize_with_visual_symbol(text) + image_token_id = self.get_token_value("image_token") + video_token_id = self.get_token_value("video_token") + replaced_ids_list = [] + image_idx = 0 + video_idx = 0 + for ids_tensor in tokenized_batched_text: + has_image_tokens = (image_token_id in ids_tensor + and "image_placeholders" in visual_features + and image_idx < len( + visual_features["image_placeholders"])) + has_video_tokens = (video_token_id in ids_tensor + and "video_placeholders" in visual_features + and video_idx < len( + visual_features["video_placeholders"])) + if has_image_tokens or has_video_tokens: + # Convert to list for easier manipulation + ids_list = ids_tensor.tolist() + new_ids = [] + + # Replace placeholders + for token_id in ids_list: + if token_id == image_token_id: + new_ids.extend( + visual_features["image_placeholders"] + [image_idx]) + image_idx += 1 + elif token_id == video_token_id: + new_ids.extend( + visual_features["video_placeholders"] + [video_idx]) + video_idx += 1 + else: + new_ids.append(token_id) + # Convert back to tensor + ids_tensor = torch.tensor(new_ids, dtype=torch.long) + replaced_ids_list.append(ids_tensor) + if replaced_ids_list: + replaced_and_tokenized_ids = torch.stack(replaced_ids_list) + else: + replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long) + output["input_ids"] = replaced_and_tokenized_ids + + return output + # If only images were provided + return BatchFeature(data=visual_features) + + def _tokenize_with_visual_symbol(self, + text_list: list[str]) -> torch.LongTensor: + batch_token_ids = [] + for text in text_list: + token_ids = [] + video_token_id = self.get_token_value("video_token") + image_token_id = self.get_token_value("image_token") + video_split_texts = text.split(self.video_token) + + for j, video_segment in enumerate(video_split_texts): + image_split_texts = video_segment.split(self.image_token) + text_chunks = [ + self.tokenizer(chunk, add_special_tokens=False).input_ids + for chunk in image_split_texts + ] + segment_tokens = [] + for i, chunk in enumerate(text_chunks): + segment_tokens.extend(chunk) + if i < len(text_chunks) - 1: + segment_tokens.append(image_token_id) + token_ids.extend(segment_tokens) + if j < len(video_split_texts) - 1: + token_ids.append(video_token_id) + + batch_token_ids.append(token_ids) + return torch.tensor(batch_token_ids, dtype=torch.long) + + # Copied from qwen2_vl + def smart_resize(self, + height: int, + width: int, + factor: int = 28, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS): + """Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range + ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if height < factor or width < factor: + print(f"height:{height} or width:{width} must be " + f"larger than factor:{factor}") + if height < width: + width = round(factor / height * width) + height = factor + else: + height = round(factor / width * height) + width = factor + + elif max(height, width) / min(height, width) > 200: + print(f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}") + if height > width: + height = 200 * width + else: + width = 200 * height + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + def get_token_value(self, tok): + return self.extra_special_tokens[tok] + + def construct_visual_indicators(self, grid, is_video: bool = False): + if is_video: + start_token = self.get_token_value('video_start') + end_token = self.get_token_value('video_end') + else: + start_token = self.get_token_value('image_start') + end_token = self.get_token_value('image_end') + + image_placeholders = [start_token, self.get_token_value('visual_atom')] + if grid[0] * grid[1] > 1: + for r in range(grid[0]): + for c in range(grid[1]): + image_placeholders.append( + self.get_token_value('visual_atom')) + + image_placeholders.append(end_token) + return image_placeholders + + def construct_visual_placeholders(self, grid, is_video: bool = False): + visual_placeholders = self.construct_visual_indicators((1, 1), + is_video) + + image_atom_token_id = self.get_token_value('visual_atom') + # Extract the padding token ID from tokenizer + image_padding_token_id = self.get_token_value('image_pad') + + num_image_atoms = grid[0] * grid[1] * grid[2] + num_image_atoms //= self.hidden_stride**2 + num_image_atoms //= self.temporal_patch_size + + # Create a new list with padding tokens inserted + padded_placeholder_tokens = [] + for token in visual_placeholders: + if token == image_atom_token_id: + padded_placeholder_tokens.extend([image_padding_token_id] * + num_image_atoms) + else: + padded_placeholder_tokens.append(image_padding_token_id) + return padded_placeholder_tokens + + def preprocess_multidata( + self, + images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, + video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = None, + convert_to_rgb: Optional[bool] = True, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + return_tensors: Optional[str] = 'pt', + ): + is_video = False + if images is not None: + if not isinstance(images, list): + images = [images] + elif video is not None: + is_video = True + # type of vidoe in dummy_mm_data is np.ndarray + if isinstance(video, np.ndarray): + images = [] + for i in range(video.shape[0]): + image = PIL.Image.fromarray(video[i].astype(np.uint8)) + images.append(image) + elif isinstance(video, list): + images = video + min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS, + min_pixels if min_pixels is not None else MIN_PIXELS) + images = [ + image.convert("RGB") + if convert_to_rgb and image.mode != 'RGB' else image + for image in images + ] + + width, height = images[0].size + resized_height, resized_width = height, width + processed_images = [] + for image in images: + resized_height, resized_width = self.smart_resize( + height, + width, + factor=self.patch_size * self.hidden_stride, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + new_size = dict(height=resized_height, width=resized_width) + image_pt = self.image_processor.preprocess( + image, size=new_size, return_tensors="np")['pixel_values'][0] + + processed_images.append(image_pt) + + patches = np.array(processed_images) + if patches.shape[0] % self.temporal_patch_size != 0: + num_to_pad = self.temporal_patch_size - (patches.shape[0] % + self.temporal_patch_size) + repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0) + patches = np.concatenate([patches, repeats], axis=0) + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = resized_height // self.patch_size + grid_w = resized_width // self.patch_size + + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.hidden_stride, + self.hidden_stride, + self.patch_size, + grid_w // self.hidden_stride, + self.hidden_stride, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * self.temporal_patch_size * + self.patch_size * self.patch_size) + + visual_placeholders = self.construct_visual_placeholders( + [grid_t, grid_h, grid_w], is_video) + return torch.tensor( + flatten_patches), visual_placeholders, torch.tensor( + [[grid_t, grid_h, grid_w]]) + + +AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2321ff0bae54b667923a61d5a3c7175e49063308..8409b73ad40470822c10c2501649f9447e1e366f 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -7,7 +7,6 @@ import os import warnings from functools import lru_cache from pathlib import Path -from types import MethodType from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub @@ -50,12 +49,11 @@ def decode_tokens( `skip_special_tokens=None` means to use the backend's default settings. """ - decode_method = getattr(tokenizer, "_decode", tokenizer.decode) if skip_special_tokens is not None: - return decode_method(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) - return decode_method(token_ids) + return tokenizer.decode(token_ids) def encode_tokens( @@ -144,26 +142,6 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: return cached_tokenizer -def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: - """Patch _pad method to accept `padding_side` for older tokenizers.""" - orig_pad = tokenizer._pad - - def _pad( - self: PreTrainedTokenizer, - *args, - padding_side: Optional[str] = None, - **kwargs, - ): - if padding_side is not None and padding_side != self.padding_side: - msg = ("`padding_side` argument is not supported by " - f"{type(tokenizer).__name__} and will be ignored.") - warnings.warn(msg, stacklevel=2) - - return orig_pad(*args, **kwargs) - - tokenizer._pad = MethodType(_pad, tokenizer) - - def get_tokenizer( tokenizer_name: Union[str, Path], *args, @@ -271,12 +249,6 @@ def get_tokenizer( } tokenizer.add_special_tokens(special_tokens_map) - # NOTE: We can remove this after https://github.com/zai-org/ChatGLM3/issues/1324 - if type(tokenizer).__name__ in ("ChatGLMTokenizer", - "ChatGLM4Tokenizer"): - assert isinstance(tokenizer, PreTrainedTokenizer) - patch_padding_side(tokenizer) - if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index a8bb0398dfdb143809bf01fdc823fe9cf6372994..ae8220f9b9dc5249f2970687cd67dc45479e39a8 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -23,6 +23,7 @@ class TokenizerGroup: self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora self.max_input_length = max_input_length + self.truncation_side = tokenizer_config.get("truncation_side", "left") self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[int, AnyTokenizer]( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 4dd8b2439b3f5b04405ece0aa4cfb32d90cf6347..f545993a5a9801e2c44972afc4fb6b120b4a3316 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -204,18 +204,16 @@ class MistralTokenizer(TokenizerBase): self.version: int = int(_mistral_version_str.split("v")[-1]) tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import ( - SpecialTokenPolicy, Tekkenizer) + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + self.is_tekken = isinstance(tokenizer_, Tekkenizer) from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer) self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - else: + self._special_token_policy = (SpecialTokenPolicy.IGNORE + if self.is_tekken else None) + if not (self.is_tekken or self.is_spm): raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") self._vocab = tokenizer_.vocab() @@ -430,7 +428,8 @@ class MistralTokenizer(TokenizerBase): return self.tokenizer.unk_id ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids) + decoded = self.tokenizer.decode(ids, + self._special_token_policy) else: decoded = "".join(tokens) else: @@ -444,7 +443,8 @@ class MistralTokenizer(TokenizerBase): if token in special_tokens: if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) regular_tokens = [] decoded_list.append(token) else: @@ -452,7 +452,8 @@ class MistralTokenizer(TokenizerBase): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) decoded = ''.join(decoded_list) @@ -470,7 +471,7 @@ class MistralTokenizer(TokenizerBase): if isinstance(ids, int): ids = [ids] - return self.tokenizer.decode(ids) + return self.tokenizer.decode(ids, self._special_token_policy) def convert_ids_to_tokens( self, @@ -511,6 +512,9 @@ class MistralTokenizer(TokenizerBase): # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 # if underlying tokenizeir is sentencepiece, we just add "�" - tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + tokens = [ + self.tokenizer.id_to_byte_piece(id, self._special_token_policy) + for id in ids + ] return tokens diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index eab8b03a4fbe5e772f13b0a5c6ec7f75141f6361..87159016067d83b2c1f45ae78898a4ff31630854 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -527,8 +527,8 @@ def random_uuid() -> str: class AsyncMicrobatchTokenizer: """Asynchronous tokenizer with micro-batching. - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. A single-thread ThreadPoolExecutor is used + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used so the event loop stays responsive. """ @@ -675,18 +675,18 @@ class AsyncMicrobatchTokenizer: def _queue_key(self, op: str, kwargs: dict) -> tuple: """ Return a normalized key describing operation + kwargs. - + - `add_special_tokens`: {True/False} - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), + - If `truncation` is False (`max_length` is None), returns a key for a can_batch queue. - If `truncation` is True and `max_length` is None or equals `tokenizer.model_max_length`, returns a key for a can_batch queue. - Otherwise, returns a key for a cannot_batch queue. - + Examples: - Decode: ("decode",) - - Encode typical: + - Encode typical: ("encode", add_special_tokens, bool_truncation, max_length_label) - Fallback: ("encode", "other") """ @@ -951,6 +951,14 @@ def get_open_port() -> int: return _get_open_port() +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: @@ -1339,6 +1347,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: return maybe_list if isinstance(maybe_list, list) else list(maybe_list) +def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: + if isinstance(obj, str) or not isinstance(obj, Iterable): + obj = [obj] + return obj + + # `collections` helpers def is_list_of( value: object, @@ -1451,6 +1465,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None: torch.cuda.set_stream = _patched_set_stream +class _StreamPlaceholder: + + def __init__(self): + self.synchronize = lambda: None + + def current_stream() -> torch.cuda.Stream: """ replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. @@ -1470,11 +1490,18 @@ def current_stream() -> torch.cuda.Stream: # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process - - # fix computational precision issue in eager mode - # _current_stream_tls.value = torch.cuda.Stream( - # ) if current_platform.is_rocm() else torch.cuda.current_stream() - _current_stream_tls.value = torch.cuda.current_stream() + # if current_platform.is_rocm(): + # _current_stream_tls.value = torch.cuda.Stream() + if current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API") return _current_stream_tls.value @@ -1972,15 +1999,18 @@ class FlexibleArgumentParser(ArgumentParser): file_path = args[index + 1] - config_args = self._load_config_file(file_path) + config_args = self.load_config_file(file_path) - # 0th index is for {serve,chat,complete} + # 0th index might be the sub command {serve,chat,complete,...} # optionally followed by model_tag (only for serve) # followed by config args # followed by rest of cli args. # maintaining this order will enforce the precedence # of cli > config > defaults - if args[0] == "serve": + if args[0].startswith('-'): + # No sub command (e.g., api_server entry point) + args = config_args + args[0:index] + args[index + 2:] + elif args[0] == "serve": model_in_cli = len(args) > 1 and not args[1].startswith('-') model_in_config = any(arg == '--model' for arg in config_args) @@ -2003,7 +2033,7 @@ class FlexibleArgumentParser(ArgumentParser): return args - def _load_config_file(self, file_path: str) -> list[str]: + def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -2044,6 +2074,11 @@ class FlexibleArgumentParser(ArgumentParser): if isinstance(value, bool) and key not in store_boolean_arguments: if value: processed_args.append('--' + key) + elif isinstance(value, list): + if value: + processed_args.append('--' + key) + for item in value: + processed_args.append(str(item)) else: processed_args.append('--' + key) processed_args.append(str(value)) @@ -2634,7 +2669,7 @@ class PlaceholderModule(_PlaceholderBase): A placeholder object to use when a module does not exist. This enables more informative errors when trying to access attributes - of a module that does not exists. + of a module that does not exist. """ def __init__(self, name: str) -> None: @@ -3261,7 +3296,7 @@ class LazyLoader(types.ModuleType): """ LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py - with a addition of "module caching". + with an addition of "module caching". Lazily import a module, mainly to avoid pulling in large dependencies. Modules such as `xgrammar` might do additional side effects, so we @@ -3428,7 +3463,7 @@ def sha256_cbor_64bit(input) -> int: return full_hash & ((1 << 64) - 1) -def get_hash_fn_by_name(hash_fn_name: str) -> Callable: +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: """Get a hash function by name, or raise an error if the function is not found. Args: diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 861d9c0c0005d0c0cce9892cadd5e62cc23cc146..90cdd396209c73d9500cabd26756d6074a507926 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -27,41 +27,37 @@ def is_deep_gemm_supported() -> bool: is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) or current_platform.is_device_capability(100)) - return has_deep_gemm() and is_supported_arch + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @functools.cache -def is_blackwell_deep_gemm_e8m0_used() -> bool: +def is_deep_gemm_e8m0_used() -> bool: """Return ``True`` if vLLM is configured to use DeepGEMM " - "E8M0 scale on a Blackwell-class GPU. + "E8M0 scale on a Hopper or Blackwell-class GPU. """ - if not (envs.VLLM_USE_DEEP_GEMM): - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.") - return False - - if not has_deep_gemm(): - logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.") - return False - - if not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") + if not is_deep_gemm_supported(): + logger.debug_once( + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") return False _lazy_init() if _fp8_gemm_nt_impl is None: - logger.debug_once( - "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - enabled = (current_platform.is_cuda() - and current_platform.has_device_capability(100)) - if enabled: - logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") - else: - logger.debug_once( - "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") - return enabled + if current_platform.is_device_capability(100) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + return True + + if current_platform.is_device_capability(90) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: @@ -127,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _fp8_gemm_nt_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _grouped_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -148,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -202,12 +194,19 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim +def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, + weight: torch.Tensor): + return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + + __all__ = [ "calc_diff", "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_e8m0_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "should_use_deepgemm_for_fp8_linear", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 2e31b7bad74760863ab5897c96e8b742f9f80c79..fab134733d4fd49f553973e5e16211c3c1b10e2e 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -132,6 +132,11 @@ def has_nvidia_artifactory() -> bool: This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. """ + # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when + # it's true, we could assume the cubins are available. + if envs.VLLM_HAS_FLASHINFER_CUBIN: + return True + try: # Use a short timeout to avoid blocking for too long response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5) @@ -174,21 +179,30 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]: def use_trtllm_attention( + num_qo_heads: int, + num_kv_heads: int, num_tokens: int, max_seq_len: int, kv_cache_dtype: str, - num_qo_heads: Optional[int], - num_kv_heads: Optional[int], - attn_head_size: Optional[int], + q_dtype: torch.dtype, + is_prefill: bool, has_sinks: bool = False, ) -> bool: use_trtllm, env_value = supports_trtllm_attention() if not use_trtllm: return False - # Check if the dimensions are supported by TRTLLM decode attention - if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads % num_kv_heads != 0): + if num_qo_heads % num_kv_heads != 0: + return False + + # Must use TRTLLM attention if query is FP8 quantized + if q_dtype == current_platform.fp8_dtype(): + logger.info_once("Using TRTLLM attention (query is quantized).") + return True + + # TRTLLM prefill attention does not support FP8 kv cache with + # non-quantized query + if is_prefill and kv_cache_dtype.startswith("fp8"): return False # If sinks are being used, we must use TRTLLM attention as it's @@ -251,6 +265,37 @@ if has_flashinfer(): dtype=dtype, device=A.device) + @torch.library.custom_op( + "vllm::bmm_fp8", + mutates_args=[], + device_types="cuda", + ) + def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import bmm_fp8 as bmm_fp8_ + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) + + @torch.library.register_fake("vllm::bmm_fp8", ) + def bmm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + A.shape[1], + B.shape[2], + dtype=dtype, + device=A.device) + def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, @@ -279,6 +324,35 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, ) +def flashinfer_scaled_fp8_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert a.shape[1] == b.shape[0] + assert scale_a.numel() == 1 and scale_b.numel() == 1 + assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn + assert a.device.type == "cuda" and b.device.type == "cuda" + assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32 + assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda" + + output = bmm_fp8( + a.unsqueeze(0), + b.unsqueeze(0), + scale_a, + scale_b, + out_dtype, + "auto", + ).view(a.shape[0], b.shape[1]) + + if bias is not None: + output = output + bias + return output + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -290,6 +364,8 @@ __all__ = [ "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", + "supports_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", + "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index 4cbe0f76e0067797a2d72e2a5adb9d835f35b8bf..457afb7e2c6ff65c753e29ba6dbb4a81787cea63 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helper functions to work with nested JSON structures.""" + from collections.abc import Iterable from functools import reduce from typing import Callable, TypeVar, Union, overload @@ -8,8 +9,12 @@ from typing import Callable, TypeVar, Union, overload _T = TypeVar("_T") _U = TypeVar("_U") -JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], - tuple["JSONTree[_T]", ...], _T] +JSONTree = Union[ + dict[str, "JSONTree[_T]"], + list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], + _T, +] """A nested JSON structure where the leaves need not be JSON-serializable.""" @@ -78,3 +83,8 @@ def json_reduce_leaves( json_iter_leaves(value), initial, ) + + +def json_count_leaves(value: JSONTree[_T]) -> int: + """Count the number of leaves in a nested JSON structure.""" + return sum(1 for _ in json_iter_leaves(value)) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 9ed46331863c9b30e7fab3e2f01bce2b4a9062b3..ced8234a7b4336d7053e7c0a67d58645dbfdd346 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -490,14 +491,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: shape = + [2, num_blocks, block_size * num_kv_heads * head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TorchSDPABackendImpl") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5d9210c4980c50ddc4a6dbfaf8069eea0c55bd79..5ac2577ec45036250b7a8a148d4bcbcf83367070 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -270,7 +270,7 @@ class FlashAttentionMetadataBuilder( num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -459,13 +459,6 @@ class FlashAttentionImpl(AttentionImpl): FlashAttentionBackend.validate_head_size(head_size) - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/decoder cross-attention " - "is not implemented for " - "FlashAttentionImpl") - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -491,6 +484,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -498,7 +492,8 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -508,7 +503,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -531,7 +526,7 @@ class FlashAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, ): + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return self._forward_encoder_attention(query[:num_actual_tokens], @@ -546,7 +541,11 @@ class FlashAttentionImpl(AttentionImpl): else: key_cache, value_cache = kv_cache - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if (self.kv_sharing_target_layer_name is None and key is not None + and value is not None): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -597,7 +596,7 @@ class FlashAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) if not current_platform.is_rocm(): flash_attn_varlen_func( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 991904229fd7f935eebd517b3e5d1124ddecd504..5fc3a1517b6905f4bea1faa3391d1274bf4a6505 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -6,21 +6,27 @@ from __future__ import annotations from dataclasses import dataclass from typing import ClassVar, Optional, Union +import numpy as np import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) -from flashinfer.decode import (_get_range_buf, get_seq_lens, - trtllm_batch_decode_with_kv_cache) +from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor -import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_attention +from vllm.utils.flashinfer import (supports_trtllm_attention, + use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ -31,10 +37,14 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills) +# yapf: enable from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + logger = init_logger(__name__) @@ -115,35 +125,6 @@ class FlashInferMetadata: num_actual_tokens: int # Number of tokens excluding padding. - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - qo_indptr_cpu: torch.Tensor - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) - paged_kv_indptr_cpu: torch.Tensor - # The page indices of the paged kv cache (on device for plan) - paged_kv_indices: torch.Tensor - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] (CPU for plan) - paged_kv_last_page_len_cpu: torch.Tensor - # The number of query/output heads - num_qo_heads: int - # The number of key/value heads - num_kv_heads: int - # The dimension of the attention heads - head_dim: int - # Block size of vllm - page_size: int - # The data type of the paged kv cache - kv_data_type: torch.dtype # The data type of the query q_data_type: torch.dtype @@ -165,10 +146,6 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). use_cascade: bool - shared_qo_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indices_cpu: Optional[torch.Tensor] = None - shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None @@ -177,10 +154,6 @@ class FlashInferMetadata: qo_indptr_gpu: Optional[torch.Tensor] = None paged_kv_indptr_gpu: Optional[torch.Tensor] = None - def __post_init__(self): - if self.head_dim is not None: - FlashInferBackend.validate_head_size(self.head_dim) - class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = \ @@ -193,13 +166,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.device = device self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + max_num_pages_per_req = cdiv(self.model_config.max_model_len, self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req @@ -213,12 +187,37 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self._decode_cudagraph_max_bs = min( max_num_reqs, self.compilation_config.max_capture_size) + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config) + self.num_kv_heads = self.kv_cache_spec.num_kv_heads + self.head_dim = self.kv_cache_spec.head_size + FlashInferBackend.validate_head_size(self.head_dim) + self.page_size = self.kv_cache_spec.block_size + + self.enable_fusion = ( + self.compilation_config.pass_config.enable_attn_fusion) + self.q_data_type = self.model_config.dtype + self.cache_dtype = self.cache_config.cache_dtype + if self.cache_dtype.startswith("fp8"): + self.kv_cache_dtype = ( + FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype)) + # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled + if self.enable_fusion: + self.q_data_type = self.kv_cache_dtype + else: + self.kv_cache_dtype = self.kv_cache_spec.dtype + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, @@ -237,6 +236,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() + self.paged_kv_indptr_buffer = torch.zeros_like( + self.paged_kv_indptr_cpu, pin_memory=pin_memory) self.paged_kv_indices_cpu = torch.zeros(max_num_pages, dtype=torch.int32, device="cpu", @@ -245,10 +247,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) - - self.block_table_arange = torch.arange(max_num_pages_per_req, - dtype=torch.int32, - device=self.device) + self.paged_kv_last_page_len_np = ( + self.paged_kv_last_page_len_cpu.numpy()) def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -274,14 +274,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): decode_wrapper = self._decode_wrapper if decode_wrapper is None: - num_qo_heads = ( - self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config)) - num_kv_heads = self.vllm_config.model_config.get_num_kv_heads( - self.vllm_config.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - if use_cudagraph: paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] paged_kv_indices = self.paged_kv_indices @@ -298,7 +290,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indptr_buffer=paged_kv_indptr, paged_kv_indices_buffer=paged_kv_indices, paged_kv_last_page_len_buffer=paged_kv_last_page_len, - use_tensor_cores=use_tensor_cores) + # Tensor cores are enabled by default because the perf would be + # atleast as good as cuda cores for all attention ops in latest + # gpus. + use_tensor_cores=True, + ) # save the decode wrapper if use_cudagraph: @@ -314,36 +310,144 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashInferMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ + split_decodes_and_prefills(common_attn_metadata) + + page_size = self.page_size + max_q_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() + block_table_tensor = common_attn_metadata.block_table_tensor + + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size + + use_cascade = common_prefix_len > 0 + if use_cascade: + # Grab the blocks of the shared prefix from the first request. + assert common_prefix_len % page_size == 0 + num_common_kv_blocks = common_prefix_len // page_size + + # Create CPU versions directly for cascade (no GPU versions needed) + shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device='cpu') + shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device='cpu') + shared_kv_page_indices_cpu = block_table_tensor[ + 0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor([page_size], + dtype=torch.int32, + device='cpu') + + # Remove the blocks of the shared prefix from all requests. + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] + num_blocks_np -= num_common_kv_blocks + else: + shared_qo_indptr_cpu = None + shared_kv_page_indptr_cpu = None + shared_kv_page_indices_cpu = None + shared_kv_last_page_len_cpu = None + + # write self.paged_kv_indptr_cpu inplace (0-index is always 0) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1:num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. + self.paged_kv_indptr_buffer[:num_reqs + + 1] = (self.paged_kv_indptr_cpu[:num_reqs + + 1]) + paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] + paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1], + non_blocking=True) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr_np[num_reqs] + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs, )]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) + + # write self.paged_kv_last_page_len_cpu inplace + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) + + # Check if any layer uses sinks (requires TRTLLM attention) + prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks) + decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks) + + attn_metadata = FlashInferMetadata( + num_actual_tokens=num_actual_tokens, + q_data_type=self.q_data_type, + slot_mapping=common_attn_metadata.slot_mapping, + max_q_len=max_q_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + prefill_use_trtllm=prefill_use_trtllm, + decode_use_trtllm=decode_use_trtllm, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + use_cascade=use_cascade, + ) + + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( - [ - attn_metadata.shared_qo_indptr_cpu, - attn_metadata.qo_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indptr_cpu, - attn_metadata.paged_kv_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indices_cpu, - attn_metadata.paged_kv_indices - ], - [ - attn_metadata.shared_kv_last_page_len_cpu, - attn_metadata.paged_kv_last_page_len_cpu - ], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, + [shared_qo_indptr_cpu, qo_indptr_cpu], + [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], + [shared_kv_page_indices_cpu, paged_kv_indices], + [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, ) else: # Regular attention (common case). @@ -355,37 +459,34 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ + assert qo_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ + assert paged_kv_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len_cpu[ - prefill_start:].shape[0] == num_prefills + assert paged_kv_last_page_len_cpu[prefill_start:].shape[ + 0] == num_prefills # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. - qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ - prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] - paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[ - prefill_start:] + qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ + prefill_start] + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] if not attn_metadata.prefill_use_trtllm: attn_metadata.prefill_wrapper.plan( qo_indptr_cpu, paged_kv_indptr_cpu, - attn_metadata.paged_kv_indices, - attn_metadata. + paged_kv_indices, paged_kv_last_page_len_cpu[prefill_start:], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) @@ -405,7 +506,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Make sure paged_kv_indptr_cpu is not decreasing self.paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_( - attn_metadata. paged_kv_indptr_cpu[-1]) # Fill the remaining paged_kv_last_page_len_cpu with 1. # This is because flashinfer treats 0 as a full page @@ -425,150 +525,21 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): fast_plan_decode( attn_metadata.decode_wrapper, self.paged_kv_indptr_cpu[:num_input_tokens + 1], - attn_metadata.paged_kv_indices, + paged_kv_indices, self.paged_kv_last_page_len_cpu[:num_input_tokens], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, ) - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashInferMetadata: - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) - - page_size = self.kv_cache_spec.block_size - max_q_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.seq_lens_cpu.max() - seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu - block_table_tensor = common_attn_metadata.block_table_tensor - - block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size - - use_cascade = common_prefix_len > 0 - if use_cascade: - # Grab the blocks of the shared prefix from the first request. - assert common_prefix_len % page_size == 0 - num_common_kv_blocks = common_prefix_len // page_size - - # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device='cpu') - shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device='cpu') - shared_kv_page_indices_cpu = block_table_tensor[ - 0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor([page_size], - dtype=torch.int32, - device='cpu') - - # Remove the blocks of the shared prefix from all requests. - block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds_cpu -= num_common_kv_blocks - else: - shared_qo_indptr_cpu = None - shared_kv_page_indptr_cpu = None - shared_kv_page_indices_cpu = None - shared_kv_last_page_len_cpu = None - - max_num_blocks = block_table_bounds_cpu.max() - block_table_bounds = block_table_bounds_cpu.to(self.device, - non_blocking=True) - mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - # write self.paged_kv_indices inplace - num_actual_pages = torch.sum(mask) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, - out=paged_kv_indices) - - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - torch.cumsum(block_table_bounds_cpu, - dim=0, - dtype=torch.int32, - out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) - - paged_kv_last_page_len_cpu = seq_lens_cpu % page_size - # write self.paged_kv_last_page_len_cpu inplace - torch.where(paged_kv_last_page_len_cpu == 0, - torch.tensor(page_size), - paged_kv_last_page_len_cpu, - out=self.paged_kv_last_page_len_cpu[:num_reqs]) - - cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - cache_dtype) - else: - kv_cache_dtype = self.kv_cache_spec.dtype - - num_qo_heads = self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) - num_kv_heads = self.kv_cache_spec.num_kv_heads - head_dim = self.kv_cache_spec.head_size - - # Check if any layer uses sinks (requires TRTLLM attention) - has_sinks = self.global_hyperparameters.has_sinks - - # currently prefill trtllm attention does not support fp8 kv cache - prefill_use_trtllm = not cache_dtype.startswith("fp8") \ - and use_trtllm_attention( - num_prefill_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) - decode_use_trtllm = use_trtllm_attention( - num_decode_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim, has_sinks) - - attn_metadata = FlashInferMetadata( - num_actual_tokens=num_actual_tokens, - qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, - paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs], - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len_cpu=self. - paged_kv_last_page_len_cpu[:num_reqs], - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - page_size=page_size, - kv_data_type=kv_cache_dtype, - q_data_type=self.vllm_config.model_config.dtype, - slot_mapping=common_attn_metadata.slot_mapping, - max_q_len=max_q_len, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table_tensor=block_table_tensor, - prefill_use_trtllm=prefill_use_trtllm, - decode_use_trtllm=decode_use_trtllm, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, - use_cascade=use_cascade, - shared_qo_indptr_cpu=shared_qo_indptr_cpu, - shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, - shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, - shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, - ) - - self._plan(attn_metadata) - return attn_metadata def build_for_cudagraph_capture( @@ -622,6 +593,8 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) + self.window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -640,10 +613,20 @@ class FlashInferImpl(AttentionImpl): raise ValueError( "Sinks must have the same number of heads as the number of " f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}." - ) + f"{sinks.shape[0]}.") self.sinks = sinks + self.support_trtllm_attn = (supports_trtllm_attention() + and num_heads % num_kv_heads == 0) + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + self.o_sf_scale: Optional[float] = None + + def fused_output_quant_supported(self, quant_key: QuantKey): + return (self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + def forward( self, layer: torch.nn.Module, @@ -654,6 +637,7 @@ class FlashInferImpl(AttentionImpl): attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -661,26 +645,65 @@ class FlashInferImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape - - # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - - + kv_cache: KV cache tensor with different possible shapes: + - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - if attn_metadata is None: # Profiling run. return output + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is None: + assert attn_metadata.q_data_type != FP8_DTYPE, \ + "Query can only be FP8 if output fusion happened." + assert output_block_scale is None, "output_block_scale "\ + "is not supported when fusion has not happened" + else: + assert attn_metadata.q_data_type == FP8_DTYPE, \ + "Query must be FP8 when attn+quant fusion happened." + assert (attn_metadata.prefill_use_trtllm and + attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, \ + "output_block_scale should not be provided for fp8 output" + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, \ + "output_block_scale is required for nvfp4 output" + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires o scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float + + # Insert FP8 quant for query + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -718,9 +741,6 @@ class FlashInferImpl(AttentionImpl): self.kv_cache_dtype) kv_cache = kv_cache.view(torch_dtype) - window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) - # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] output_padded = output @@ -748,7 +768,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._window_left == self.window_left assert prefill_wrapper._logits_soft_cap == ( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale @@ -775,6 +795,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + trtllm_batch_context_with_kv_cache( query=prefill_query, kv_cache=kv_cache_permute, @@ -783,14 +813,15 @@ class FlashInferImpl(AttentionImpl): seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, batch_size=attn_metadata.num_prefills, cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, - window_left=window_left, + window_left=self.window_left, sinks=self.sinks, - out=output[num_decode_tokens:], + o_sf_scale=self.o_sf_scale, + out=out, ) if num_decode_tokens > 0: @@ -800,7 +831,7 @@ class FlashInferImpl(AttentionImpl): assert decode_wrapper is not None if not attn_metadata.decode_use_trtllm: - assert decode_wrapper._window_left == window_left + assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale @@ -815,8 +846,8 @@ class FlashInferImpl(AttentionImpl): # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer - block_tables_decode = attn_metadata.block_table_tensor[: - num_decode_tokens] + block_tables_decode = attn_metadata.\ + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -827,6 +858,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_decode.is_contiguous() assert seq_lens_decode.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -834,11 +875,12 @@ class FlashInferImpl(AttentionImpl): block_tables=block_tables_decode, seq_lens=seq_lens_decode, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, - window_left=window_left, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + window_left=self.window_left, sinks=self.sinks, - out=output[:num_decode_tokens], + o_sf_scale=self.o_sf_scale, + out=out, ) return output_padded @@ -848,6 +890,7 @@ def fast_plan_decode( indptr_cpu: torch.Tensor, indices: torch.Tensor, last_page_len_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -925,9 +968,6 @@ def fast_plan_decode( kv_data_type = getattr(torch, kv_data_type) if isinstance( kv_data_type, str) else kv_data_type - if self.use_tensor_cores: - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " @@ -944,56 +984,29 @@ def fast_plan_decode( self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) - indptr_host = indptr_cpu - last_page_len_host = last_page_len_cpu - - if self.use_tensor_cores: - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, - page_size) - - try: - # Make sure we pass exactly 15 arguments for tensor core version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - ) - except Exception as e: - raise RuntimeError(f"Error in tensor core plan: {e}") from e - else: - try: - # Make sure we pass exactly 15 arguments for standard version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - window_left, - logits_soft_cap, - head_dim, - head_dim, - torch.empty(0, dtype=q_data_type), - torch.empty(0, dtype=kv_data_type), - ) - except Exception as e: - raise RuntimeError(f"Error in standard plan: {e}") from e + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + + try: + # Make sure we pass exactly 15 arguments for tensor core version + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_cpu, + seq_lens_cpu, + batch_size, # total_num_rows + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + ) + except Exception as e: + raise RuntimeError(f"Error in tensor core plan: {e}") from e self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left @@ -1001,3 +1014,25 @@ def fast_plan_decode( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store(page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e599411b2d7e87c5db343095d5e32e3eb98c4093..d5b1c15e68d0e96d65c0bf603a1067cf57c1d12f 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict +"""Attention layer with FlexAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union import torch +import torch._dynamo.decorators +import torch.nn.functional as F from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, _score_mod_signature, create_block_mask, @@ -16,13 +18,17 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: torch.arange(len(counts), device=device, dtype=torch.int32), counts) +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i == dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) + + class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -77,10 +100,10 @@ class FlexAttentionBackend(AttentionBackend): return False -# @torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping( - block_table: torch.Tensor, - total_blocks: Optional[int] = None) -> torch.Tensor: +#@torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping(block_table: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, + total_blocks: int) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -114,13 +137,38 @@ def physical_to_logical_mapping( If a physical block is not mapped to by any logical block, its value in the result will be -1. + IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. Args: block_table: Tensor of shape [max_reqs, max_num_blocks] - mapping logical blocks to physical locations + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available Returns: - A tensor of shape [max_reqs, max_physical_block] + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. """ max_reqs, max_num_blocks = block_table.shape device = block_table.device @@ -130,17 +178,76 @@ def physical_to_logical_mapping( dtype=torch.long, device=device) - logical_indices = (torch.arange(max_num_blocks, - device=device).unsqueeze(0).expand( - max_reqs, -1)) + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = torch.arange(max_num_blocks, + device=device)[None, :] < num_blocks_per_seq[:, None] - physical_to_logical.scatter_(-1, block_table.to(torch.int64), - logical_indices) - # TODO Confirm - Seems like block 0 is always empty so we reset it manually + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, + torch.arange(max_num_blocks, device=device)[None, :], 0) + + physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), + valid_logical_indices) + # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory. + + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # “∞” + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) + ) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor): return q_idx >= kv_idx @@ -170,6 +277,7 @@ class FlexAttentionMetadata: num_reqs: int physical_to_logical: torch.Tensor decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -179,6 +287,46 @@ class FlexAttentionMetadata: block_mask: Optional[BlockMask] = None score_mod: Optional[_score_mod_signature] = None logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: Optional[torch.Tensor] = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: Optional[_score_mod_signature] = None + + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? + + Returns: + tuple of (is_valid, logical_q_idx, logical_kv_idx) + """ + # Map query indices to corresponding request indices + q_req = request_lookup[q_idx] + + # Convert physical KV indices to logical indices + physical_kv_block = physical_kv_idx // self.block_size + physical_kv_offset = physical_kv_idx % self.block_size + logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] + logical_kv_idx = (logical_block_idx * self.block_size + + physical_kv_offset) + + # Determine valid kv indices + live_block = logical_block_idx >= 0 + within_upper_bound = logical_kv_idx < self.seq_lens[q_req] + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + # Convert physical query indices to logical indices + local_q_idx = q_idx - self.query_start_loc[q_req] + logical_q_idx = local_q_idx + self.decode_offset[q_req] + + return is_valid, logical_q_idx, logical_kv_idx def get_causal_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. @@ -191,11 +339,8 @@ class FlexAttentionMetadata: With this info we create the "logical" indices that are passed to mask_mod functions. This allows mask mod functions to be agnostic to layout of the query and key/value tensors. - - TODO is_within_lower_bound: do sequences start on block_boundaries? """ - # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + assert self.doc_ids is not None def final_mask_mod( b: torch.Tensor, @@ -203,27 +348,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - # Map query indices to corresponding request indices - q_req = request_lookup[q_idx] - - # Convert physical KV indices to logical indices - physical_kv_block = physical_kv_idx // self.block_size - physical_kv_offset = physical_kv_idx % self.block_size - logical_block_idx = self.physical_to_logical[q_req, - physical_kv_block] - logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501 - - # Determine valid kv indices - live_block = logical_block_idx >= 0 - within_upper_bound = logical_kv_idx < self.seq_lens[q_req] - within_lower_bound = logical_kv_idx >= 0 - - is_valid = live_block & within_upper_bound & within_lower_bound - - # Convert physical query indices to logical indices - local_q_idx = q_idx - self.query_start_loc[q_req] - logical_q_idx = local_q_idx + self.decode_offset[q_req] - + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + self.doc_ids, q_idx, physical_kv_idx) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -236,7 +363,7 @@ class FlexAttentionMetadata: def get_bidirectional_mask_mod(self) -> _mask_mod_signature: """Creates the encoder mask_mod function for FlexAttention. - Since the encoder bidirectional attention doesn't run with + Since the encoder bidirectional attention doesn't run with KV cache, this function creates a mask based on the packed query sequences. """ @@ -253,6 +380,97 @@ class FlexAttentionMetadata: return final_mask_mod + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx) + + return torch.where( + is_valid, + user_score_mod(score, + b, + h, + logical_q_idx, + logical_kv_idx, + physical_q=q_idx), -float('inf')) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. For each query token, fetch blocks from block_table using max_seq_len + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. + + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). Please check your model's " + f"configuration.") + + used_pages = self.block_table[ + self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] + used_pages_padded = pad_to_multiple(used_pages, + multiple=self.q_block_size, + dim=0) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted((used_pages_padded.long()), + M=self.num_blocks).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + def build_block_mask(self) -> BlockMask: if self.causal: mask_mod = self.get_causal_mask_mod() @@ -267,6 +485,7 @@ class FlexAttentionMetadata: self.num_actual_tokens, kv_len, device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), ) def __post_init__(self): @@ -275,8 +494,21 @@ class FlexAttentionMetadata: assert self.cu_prefix_query_lens is None, "Not implemented yet." assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." + # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - self.block_mask = self.build_block_mask() + + if self.causal: + self.mask_mod = self.get_causal_mask_mod() + else: + self.mask_mod = self.get_bidirectional_mask_mod() + + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() class FlexAttentionMetadataBuilder( @@ -287,15 +519,24 @@ class FlexAttentionMetadataBuilder( self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config + self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - vllm_config.parallel_config) + self.parallel_config) self.num_heads_kv = self.model_config.get_num_kv_heads( - vllm_config.parallel_config) + self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.device = device + self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") + self.q_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, common_prefix_len: int, @@ -305,11 +546,12 @@ class FlexAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -320,12 +562,15 @@ class FlexAttentionMetadataBuilder( block_size = self.kv_cache_spec.block_size max_possible_seq_len = self.model_config.max_model_len - total_cache_tokens = self.cache_config.num_gpu_blocks * block_size + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, \ + "FlexAttention requires num_gpu_blocks to be set" + total_cache_tokens = (num_gpu_blocks * block_size) inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.cache_config.num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks) - # Get the original offset tensor offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( self.device, non_blocking=True) @@ -349,9 +594,16 @@ class FlexAttentionMetadataBuilder( physical_to_logical=inverse_block_table, total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + direct_build=self.direct_build, + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, ) return out + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class FlexAttentionImpl(AttentionImpl): sliding_window: Optional[tuple[int, int]] @@ -370,6 +622,7 @@ class FlexAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -398,6 +651,7 @@ class FlexAttentionImpl(AttentionImpl): raise NotImplementedError( "FlexAttention does not support logits soft cap yet.") + assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: @@ -405,7 +659,6 @@ class FlexAttentionImpl(AttentionImpl): "FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet") @@ -428,6 +681,7 @@ class FlexAttentionImpl(AttentionImpl): attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -435,13 +689,14 @@ class FlexAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlexAttentionImpl") @@ -492,35 +747,49 @@ class FlexAttentionImpl(AttentionImpl): # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on some GPUs - # TODO: Explicit configs for each GPU? - # Not sure how to calculate the shared memory requirement - extra_kernel_options = defaultdict[str, int](lambda: 64) - if query.dtype == torch.float32: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 - if current_platform.is_cuda(): - device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin - if max_shared_memory < 144 * 1024: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + kernel_options = get_kernel_options(query, block_m, block_n, + attn_metadata.direct_build) out = flex_attention_compiled( query, key_tensor, value_tensor, - attn_metadata.score_mod, + attn_metadata.transformed_score_mod, attn_metadata.block_mask, self.scale, enable_gqa=enable_gqa, - kernel_options={ - "FORCE_USE_FLEX_ATTENTION": True, - **extra_kernel_options - }, + kernel_options=kernel_options, ) # Flex doesn't have an out variant today, rely on epilogue fusion out = out.permute(0, 2, 1, 3).squeeze(0) output[:num_actual_tokens, :, :].copy_(out) return output + + +def get_kernel_options(query, block_m, block_n, + use_direct_build: bool) -> dict[str, Union[int, bool]]: + kernel_options: dict[str, Union[int, bool]] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + kernel_options["BLOCK_M"] = 64 + kernel_options["BLOCK_N"] = 64 + if query.dtype == torch.float32: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + # if current_platform.is_cuda(): + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2 + kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2 + + return kernel_options diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 6cdc509083ae9c00df04de210d41ebcb2a533d06..97a1aa86dda0ded755ae11ce2fbd06b0e3dad03a 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,16 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class Mamba1AttentionBackend(AttentionBackend): @@ -31,24 +31,11 @@ class Mamba1AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int + num_padded_decodes: int class Mamba1AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba1AttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 - - def __init__( - self, - kv_cache_spec: AttentionSpec, - vllm_config: VllmConfig, - device: torch.device, - layer_names: list[str], - ): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): def build( self, @@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder( decode_threshold=1)) has_initial_states = None + padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 + elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph): + state_indices_for_decode = state_indices_tensor[:num_decodes] + padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_for_decode, non_blocking=True) + state_indices_tensor = self.state_indices_tensor[:padded_decodes] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID return Mamba1AttentionMetadata( query_start_loc=query_start_loc, @@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder( num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, + num_padded_decodes=padded_decodes, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ace078e2b27c63b4e92c889089e7f1be3d7cfa93..ed30884fdbc94f16f74ee2535458927e82ef1048 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -2,18 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, @@ -88,29 +88,14 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba2AttentionMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - - reorder_batch_threshold: ClassVar[int] = 1 + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") - self.decode_cudagraph_max_bs = min( - self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) - self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), - dtype=torch.int32, - device=device, - ) def build(self, common_prefix_len: int, @@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder( state_indices_tensor=state_indices_tensor, ) return attn_metadata - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with Mamba. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..07ef7cb69a160cd5b49c172416569bb1da4e9062 --- /dev/null +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import abc +from typing import ClassVar, TypeVar + +import torch + +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + +M = TypeVar("M") + + +class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): + reorder_batch_threshold: ClassVar[int] = 1 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + self.device = device + self.vllm_config = vllm_config + self.layer_names = layer_names + + self.compilation_config = vllm_config.compilation_config + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert m.num_reqs == m.num_actual_tokens, \ + "Mamba only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + return self.build(0, m) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py deleted file mode 100644 index d3a0c63c5e964b850cad528d69fb8738a0678c1e..0000000000000000000000000000000000000000 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba1": - return Mamba1AttentionBackend - if mamba_type == "mamba2": - return Mamba2AttentionBackend - if mamba_type == "linear_attention": - return LinearAttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e2486ed3313eebcfcfa125f0beb18298a966b91b..c29b4fb96e40243aa7d3c11652644fef8a0062eb 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. @@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): @@ -197,6 +197,7 @@ from typing import ClassVar, Generic, Optional, TypeVar, Union import torch import os +from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops @@ -208,6 +209,7 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig +from vllm.distributed.parallel_state import is_global_first_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -240,6 +242,28 @@ try: except ImportError: flashinfer_available = False + +def is_rocm_aiter_fp8bmm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_FP8BMM \ + and envs.VLLM_ROCM_USE_AITER + + +if is_rocm_aiter_fp8bmm_enabled(): + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant + as aiter_triton_fp8_bmm) + + def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -423,7 +447,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -433,30 +456,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() @@ -660,8 +681,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and num_prefills > 0 \ - and max_context_len_cpu > 0: + if max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code @@ -675,8 +695,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) @@ -1047,10 +1068,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + else: + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return x def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -1101,16 +1133,57 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) def _compute_prefill_context( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1123,12 +1196,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1179,6 +1254,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None @@ -1204,7 +1280,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1228,6 +1304,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, + layer: AttentionLayer, ) -> torch.Tensor: raise NotImplementedError @@ -1241,10 +1318,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_metadata: M, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") @@ -1255,6 +1333,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # same expert outputs. return output.fill_(0) + fp8_attention = self.kv_cache_dtype.startswith("fp8") + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs @@ -1289,10 +1369,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): scale=layer._k_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None @@ -1300,12 +1383,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + else: + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) return output_padded \ No newline at end of file diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 024f38201970b7458458aee18beb967dba056267..8a17d3a49278329ce26342de8e9b3b57dfc63a67 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,7 +7,7 @@ from typing import ClassVar, Optional import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -115,7 +115,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): self._use_old_cutlass_mla = False force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) if force_old_cutlass: - logger.warning("Forcing old cutlass mla kernel") + logger.warning_once("Forcing old cutlass mla kernel") self._use_old_cutlass_mla = True # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging @@ -123,8 +123,8 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", + int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more extensive diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 8d43f9fa8f1df042510443bdf355a6e3933a2634..1c50144d4790097dc3f72afd558316a481c82774 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,8 +6,7 @@ from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -166,19 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): "are not implemented for " "FlashMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - if self.kv_cache_dtype != "fp8": - raise NotImplementedError( - "FlashMLA with other KV cache not yet supported") - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, - k_scale = None, - kv_cache_dtype = "auto", + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -197,8 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, - k_scale = k_scale, - kv_cache_dtype = kv_cache_dtype, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), ) return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 082c7e6f7c62edc3e37a4c577278a5835c379525..870cc600388e7feff10cc72ea1966af20524cd96 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -7,6 +7,7 @@ from typing import ClassVar, Optional import torch import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv @@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 700fce68953e5f78bf5cfe5afd2a8dae16632bef..f2974ed668d99dc9e250729a3bafe89ce0dc9b2f 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,7 +6,7 @@ from typing import Optional import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention @@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 9b122136afb7fff5e45e4219264692729bc10ba7..26f9abf13d0ed9efc811b437506801ae660f361b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,12 +5,6 @@ from dataclasses import dataclass from typing import Optional import torch -import torch_xla.core.xla_builder as xb -import torch_xla.experimental.custom_kernel # noqa: F401 -# Required to register custom ops. -from torch.library import impl -from torch_xla._internal.jax_workarounds import requires_jax -from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { "uint8": torch.uint8, } +try: + import tpu_commons # noqa: F401 +except ImportError: + # Lazy import torch_xla + import torch_xla.core.xla_builder as xb + import torch_xla.experimental.custom_kernel # noqa: F401 + from torch.library import impl + from torch_xla._internal.jax_workarounds import requires_jax + from torch_xla.experimental.custom_kernel import XLA_LIB + + @requires_jax + def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( + kv_cache_update, + (kv, slot_mapping, kv_cache, num_kv_update_slices), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + + XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ + "int num_slices_per_block)" \ + "-> Tensor", ) + + @impl(XLA_LIB, "kv_cache_update_op", "XLA") + def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + num_kv_update_slices, page_size, + num_slices_per_block) + return new_kv_cache + + @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") + def kv_cache_update_op_non_xla(kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + return kv_cache + class PallasAttentionBackend(AttentionBackend): @@ -182,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl): attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -189,12 +235,13 @@ class PallasAttentionBackendImpl(AttentionImpl): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads * 2, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for PallasAttentionBackendImpl") @@ -283,7 +330,7 @@ def write_to_kv_cache( Args: key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size] num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape @@ -313,46 +360,6 @@ def write_to_kv_cache( kv_cache.copy_(new_kv_cache) -@requires_jax -def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - new_kv_cache = xb.call_jax( - kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) - return new_kv_cache - - -XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \ - "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \ - "-> Tensor", ) - - -@impl(XLA_LIB, "kv_cache_update_op", "XLA") -def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) - return new_kv_cache - - -@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") -def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - return kv_cache - - # We can move this function to a common utils file if it's also useful for other # hardware. def dtype_bits(dtype: torch.dtype): diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 7d09ac0a4a3a1364f9a0526f38c1eab1fc4c55df..173a0a255e49146c2d175effefd47594e5540d06 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec @@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True + cudagraph_support = AttentionCGSupport.ALWAYS def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -269,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor @@ -420,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl): attn_metadata: AiterFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -427,7 +429,8 @@ class AiterFlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -437,7 +440,7 @@ class AiterFlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d80ced8ec876a100260d49c5f88248f6e33ec454 --- /dev/null +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class ShortConvAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: + return ShortConvAttentionMetadataBuilder + + +@dataclass +class ShortConvAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + + query_start_loc: torch.Tensor + has_initial_states: torch.Tensor + state_indices_tensor: torch.Tensor # shape: [batch,] + + # For causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None + + +class ShortConvAttentionMetadataBuilder( + AttentionMetadataBuilder[ShortConvAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> ShortConvAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + has_initial_states = None + if num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + common_attn_metadata. + num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + has_initial_states = has_initial_states_cpu.to( + query_start_loc.device) + + attn_metadata = ShortConvAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata \ No newline at end of file diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 5d10e9e26082d0bf521701f89b64cb134605fcad..b96d957a150b56a70fd22e588536168f0ba7c91b 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder( q_start_loc = common_attn_metadata.query_start_loc max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl): attn_metadata: TreeAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -361,14 +362,15 @@ class TreeAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TreeAttentionImpl") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 48a9af3decac0422367106898b00fea74244cc31..a37a7f6811ef9ffb000ebdf34dcf08944955402b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor @@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -284,14 +285,15 @@ class TritonAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TritonAttentionImpl") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 4da1c8155b41af8394e041be5612bd33e90ff68a..9466238d2631aa2b51e7b3e7c7321a5ac9b95b7b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,12 +4,13 @@ import abc import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, +from dataclasses import dataclass, fields, make_dataclass +from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, TypeVar) import numpy as np import torch +from typing_extensions import runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -20,7 +21,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) @@ -58,7 +60,9 @@ class CommonAttentionMetadata: """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" - + max_seq_len: int + """Longest context length in batch""" + block_table_tensor: torch.Tensor num_speculative_tokens: int = 0 @@ -69,6 +73,10 @@ class CommonAttentionMetadata: causal: bool = True + # Needed by FastPrefillAttentionBuilder + logits_indices_padded: Optional[torch.Tensor] = None + num_logits_indices: Optional[int] = None + @dataclass class UbatchSlice: @@ -112,6 +120,7 @@ def _make_metadata_with_slice( seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -133,6 +142,7 @@ def _make_metadata_with_slice( num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, ) @@ -207,6 +217,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + input_batch: input batch + scheduler_output: scheduler output. + + Returns: + True if the batch was modified, False otherwise. + """ + raise NotImplementedError + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata) -> M: """ @@ -469,8 +496,9 @@ def make_local_attention_virtual_batches( attn_chunk_size)[arange > 0] # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) + cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) + np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:]) + cu_seqlens_q_local[0] = 0 # compute the seqlens_k_local, # basically a full local attention block for all but the last block in each @@ -513,11 +541,10 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + block_indices = (block_starts[:, None] + + np.arange(pages_per_local_batch, dtype=np.int32)) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - + 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\ @@ -525,6 +552,7 @@ def make_local_attention_virtual_batches( query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) + max_seq_len = int(seq_lens_cpu.max()) return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, @@ -536,39 +564,74 @@ def make_local_attention_virtual_batches( num_reqs=len(seq_lens_cpu), num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=seqlens_q_local.max(), + max_seq_len=max_seq_len, block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, ) -def subclass_attention_metadata_builder( - name_prefix: str, - builder_cls: type[AttentionMetadataBuilder[M]], - build_preprocess_fn: Callable[[CommonAttentionMetadata], - CommonAttentionMetadata], -) -> type[AttentionMetadataBuilder[M]]: - """ - Return a new subclass of `builder_cls` whose .build(...) method - first calls build_preprocess_fn(common_attn_metadata) on the metadata. - """ - name: str = name_prefix + builder_cls.__name__ # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False): - return builder_cls.build(self, common_prefix_len, - build_preprocess_fn(common_attn_metadata), - fast_build) - - Wrapped = type( - name, - (builder_cls, ), # inherit from the original - { - "build": build, - }) - return Wrapped # type: ignore +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + assert common_attn_metadata.logits_indices_padded is not None + assert common_attn_metadata.num_logits_indices is not None + + logits_indices_padded = common_attn_metadata.logits_indices_padded + num_logits_indices = common_attn_metadata.num_logits_indices + # Get rid of CUDAGraph padding, if any + logits_indices = logits_indices_padded[:num_logits_indices] + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + query_start_loc_cpu=decode_query_start_loc.to("cpu", + non_blocking=True), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=True, + ) + return common_attn_metadata def subclass_attention_backend( @@ -727,13 +790,56 @@ def subclass_attention_metadata( return Wrapped -def make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls: Any, ) -> Any: - """ - Return a new subclass of `metadata_cls` for fast prefill - """ - return subclass_attention_metadata( - name_prefix="KVSharingFastPrefill", - metadata_cls=metadata_cls, - fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, - ) +@runtime_checkable +class KVSharingFastPrefillMetadata(Protocol): + logits_indices_padded: torch.Tensor + num_logits_indices: int + + +def create_fast_prefill_custom_backend( + prefix: str, + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: + + underlying_builder = underlying_attn_backend.get_builder_cls() + + class FastPrefillAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata =\ + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + metadata = super().build(common_prefix_len, + new_common_attn_metadata, fast_build) + + class KVSharingFastPrefillAttentionMetadata( + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata): + + def __init__(self, metadata, common_attn_metadata): + # Shallow copy all fields in metadata cls + for field in fields(metadata.__class__): + setattr(self, field.name, + getattr(metadata, field.name)) + + # Set additional fields that will be used in model code + assert (common_attn_metadata.logits_indices_padded + is not None + and common_attn_metadata.num_logits_indices + is not None) + self.logits_indices_padded = \ + common_attn_metadata.logits_indices_padded + self.num_logits_indices = \ + common_attn_metadata.num_logits_indices + + return KVSharingFastPrefillAttentionMetadata( + metadata, common_attn_metadata) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=FastPrefillAttentionBuilder) + + return attn_backend diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index fe732c601770296a710964a00f3c0d2ed0300a1c..7f888c113574312265d459bbd8f5087de0ddefc9 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder( q_seqlens = torch.diff(q_start_loc) max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl): attn_metadata: XFormersAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with XFormers. @@ -329,14 +330,15 @@ class XFormersAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersAttentionImpl") diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 839297135fe0aff6f86ff5d38e54b6f996cbd8d7..b537cac8e1d72e800bb8aca0ee7933f4e9c355c6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -4,8 +4,9 @@ from collections import defaultdict from collections.abc import Iterable from typing import Optional -from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, - BlockStored, KVCacheEvent) +from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, + BlockRemoved, BlockStored, + KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, FreeKVCacheBlockQueue, KVCacheBlock) @@ -156,6 +157,7 @@ class BlockPool: block_size=block_size, lora_id=request.lora_request.id if request.lora_request else None, + medium=MEDIUM_GPU, )) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -218,7 +220,8 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()], + medium=MEDIUM_GPU)) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -298,7 +301,12 @@ class BlockPool: Returns: The KV cache usage (between 0.0 and 1.0). """ - return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) + + # Subtract 1 to account for null block. + total_gpu_blocks = self.num_gpu_blocks - 1 + if not total_gpu_blocks: + return 0 + return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks) def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index faf5c132f8640ebd772a3ff81c322a3d44694ff7..bd2ec036834b26ba1c55415ad2ca24846eb84221 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Mapping from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -31,34 +33,52 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Note that no caching is shared between requests at this time. If the same - input is used across multiple requests, it will be reprocessed for each - request. + Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free + space for new embeddings. + Oldest cached embeddings with no request referenced will be first evicted. Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens - num_free_slots: Current available cache capacity in encoder tokens - cached: Mapping from request_id to set of cached input_ids for that - request - freed: List of (request_id, input_id) pairs that were recently freed. - This is cleared after every call to get_freed_ids(). + cache_size: Total cache capacity in encoder tokens. + num_free_slots: Current available cache capacity in encoder tokens. + num_freeable_slots: Capacity that can be immediately reclaimed by + evicting entries with zero references (in encoder tokens). + cached: Mapping from mm_hash to a set of request IDs that currently + reference the cached entry. If the set is empty, the entry exists + but is not referenced by any request and is eligible for + reclamation. + freeable: List of tuples (mm_hash, num_tokens) representing entries + whose no current running request is needed and that can be freed to + make space when needed. + freed: List of mm_hash strings that were actually evicted since the + last call to get_freed_mm_hashes(). This list is cleared on return. """ def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size - # req_id -> cached input ids - self.cached: dict[str, set[int]] = {} - # list of [req_id, input_id] - self.freed: list[tuple[str, int]] = [] + self.num_freeable_slots = cache_size - def has_cache(self, request: Request, input_id: int) -> bool: + # mm_hash of mm_data => ids of requests that reference the mm_data + self.cached: dict[str, set[str]] = {} + + # mm_hash of mm_data => num_encoder_tokens of the mm_data + self.freeable: OrderedDict[str, int] = OrderedDict() + self.freed: list[str] = [] + + def check_and_update_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. + If the encoder output is cached, update `cached` to add the request id + to the set of request ids that reference the cached encoder output. + If the encoder output was previously not referenced by any request, + update `freeable` and `num_freeable_slots` accordingly. + Args: request: The request containing the multimodal input input_id: Index of the multimodal input within the request @@ -66,103 +86,159 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - req_id = request.request_id - return req_id in self.cached and input_id in self.cached[req_id] - - def can_allocate(self, request: Request, input_id: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. + mm_hash = request.mm_hashes[input_id] + # Not cached at all + if mm_hash not in self.cached: + return False + + # Cached but currently not referenced by any request + if not self.cached[mm_hash]: + num_tokens = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_tokens + + self.cached[mm_hash].add(request.request_id) + return True + + def can_allocate(self, request: Request, input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int) -> bool: + """Check if there's sufficient cache space for a multimodal input. + If there is, return True and update EncoderCacheManager state. + + If there is not enough free space in `num_free_slots` but there is + enough reclaimable space in `num_freeable_slots`, entries will be + evicted from `freeable` (their mm_hash appended to `freed`) until + enough space is available, and then this method returns True. + Older entries are evicted first. + + Returns False only if the requested number of tokens exceeds both + the free and reclaimable capacities combined. Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + request: The request containing the multimodal input. + input_id: Index of the multimodal input within the request. + encoder_compute_budget: Number of encoder tokens allowed to be + computed when this method is invoked. + num_tokens_to_schedule: Number of tokens already scheduled to be + allocated with cache space when this method is invoked. Returns: - True if there's enough free cache space to store the encoder output - for this multimodal input + True if there's enough capacity to hold the encoder output for this + input (possibly after reclaiming `freeable` entries); otherwise + False. + + Note: This method does not allocate physical memory for the encoder + output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) - return num_tokens <= self.num_free_slots + + # Not enough compute budget + if num_tokens > encoder_compute_budget: + return False + + num_tokens += num_tokens_to_schedule + + # Enough free slots + if num_tokens <= self.num_free_slots: + return True + + # Not enough reclaimable slots + if num_tokens > self.num_freeable_slots: + return False + + # Not enough free slots but enough reclaimable slots + # NOTE: Eviction takes place here, but physical memory is not freed + # until model runner is notified by the scheduler output. + while num_tokens > self.num_free_slots: + mm_hash, num_free_token = self.freeable.popitem(last=False) + del self.cached[mm_hash] + self.freed.append(mm_hash) + self.num_free_slots += num_free_token + return True def allocate(self, request: Request, input_id: int) -> None: """Allocate cache space for a multimodal input's encoder output. - This method reserves cache space for storing the encoder output of - the specified multimodal input. The actual encoder output storage - happens in the model runner, but this method ensures the cache - manager tracks the allocation. - - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + This reserves cache space for storing the encoder output of the + specified multimodal input. The actual encoder output storage happens in + the model runner; this method updates the manager's bookkeeping. Note: - This method assumes can_allocate() returned True for the same - request and input_id. It will reduce available cache space. + This method assumes can_allocate() returned True for the same input. """ - req_id = request.request_id - if req_id not in self.cached: - self.cached[req_id] = set() - self.cached[req_id].add(input_id) - self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + mm_hash = request.mm_hashes[input_id] + request_id = request.request_id + if mm_hash not in self.cached: + self.cached[mm_hash] = set() + + num_encoder_tokens = request.get_num_encoder_tokens(input_id) + + # NOTE: Encoder cache should always have enough space for encoder inputs + # that are scheduled since eviction takes place at can_allocate(). + assert self.num_free_slots >= num_encoder_tokens + assert self.num_freeable_slots >= num_encoder_tokens + + self.cached[mm_hash].add(request_id) + self.num_free_slots -= num_encoder_tokens + self.num_freeable_slots -= num_encoder_tokens def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. - Args: - request: The request to query - - Returns: - Set of input_ids that have cached encoder outputs for this request. - Returns empty set if no inputs are cached for this request. + Returns the set of input IDs whose `mm_hash` exists in the cache map. + This includes entries that are currently unreferenced (and thus present + in `freeable`); for such entries, freeing for this request will be a + no-op. """ - return self.cached.get(request.request_id, set()) + return { + input_id + for input_id in range(len(request.mm_hashes)) + if request.mm_hashes[input_id] in self.cached + } def free_encoder_input(self, request: Request, input_id: int) -> None: - """Free cache space for a single multimodal input's encoder output. + """Free the request's reference to the encoder input (`mm_data`) - This method is called when: - - The encoder output has been fully consumed by the decoder and is - no longer needed (e.g., in vision-language models after image - tokens are processed) - - A request is being cancelled or aborted + When the reference set for the corresponding `mm_hash` becomes empty, + the entry is appended to `freeable` and `num_freeable_slots` is + increased by the number of encoder tokens for that input. - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input to free from cache + The entry is NOT physically freed until capacity is needed (e.g., by + `can_allocate`). """ req_id = request.request_id - if req_id not in self.cached: + mm_hash = request.mm_hashes[input_id] + # The mm_hash not in cache or the req_id set is empty + if not self.cached.get(mm_hash, None): return - - self.cached[req_id].discard(input_id) - if len(self.cached[req_id]) == 0: - del self.cached[req_id] - self.num_free_slots += request.get_num_encoder_tokens(input_id) - self.freed.append((req_id, input_id)) + self.cached[mm_hash].discard(req_id) + if not self.cached[mm_hash]: + num_tokens = request.get_num_encoder_tokens(input_id) + self.freeable[mm_hash] = num_tokens + self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: - """Free all cached encoder outputs for a request. + """Free all encoder input cache reference held by *request*. - This method is typically called when a request is finished, cancelled, - or aborted, and all its encoder outputs should be freed from cache. + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future + attempt allocation called by 'can_allocate'. - Args: - request: The request whose encoder outputs should be freed + Typically called when a request is finished, cancelled, or aborted. """ input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> list[tuple[str, int]]: + def get_freed_mm_hashes(self) -> list[str]: """Get and clear the list of recently freed encoder cache entries. - This method returns all encoder cache entries that were freed since - the last call to this method. It's used by the scheduler to notify - workers about which encoder outputs can be removed from their caches. - Returns: - List of (request_id, input_id) tuples that were freed since the - last call. The internal freed list is cleared after this call. + List of mm_hash strings that were actually evicted since the last + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -177,10 +253,31 @@ def compute_encoder_budget( """Compute the encoder cache budget based on the model and scheduler configurations. + Returns: + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. + """ + if mm_registry.supports_multimodal_inputs(model_config): + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) + + return compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) + + return compute_text_encoder_budget(scheduler_config) + + +def compute_text_encoder_budget( + scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a text-only model. + Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. Returns: - Compute budget for encoder execution, in unit of number of tokens @@ -188,55 +285,37 @@ def compute_encoder_budget( - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ - - if not mm_registry.supports_multimodal_inputs(model_config): - return 0, 0 - - # TODO: handle encoder-decoder models once we support them. - ( - encoder_compute_budget, - encoder_cache_size, - ) = _compute_encoder_budget_multimodal( - model_config, - scheduler_config, - mm_registry, - ) - - return encoder_compute_budget, encoder_cache_size + # Currently text-only encoder-decoder models are not supported + return 0, 0 -def _compute_encoder_budget_multimodal( - model_config: "ModelConfig", +def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", - mm_registry: MultiModalRegistry, + max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. + max_tokens_by_modality: The maximum number of tokens for each + non-text modality. Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ - max_tokens_by_modality_dict = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - - if not max_tokens_by_modality_dict: + if not max_tokens_by_modality: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " "not be initialized.") return 0, 0 - _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), - key=lambda item: item[1]) + max_tokens_per_mm_item = max(max_tokens_by_modality.values()) if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens): diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index a0ea4d96015a2abefc01549972e3c63931cda162..9421341f990c859161477cd4833447d94ecf50a7 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,7 +6,7 @@ from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) + CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.request import Request @@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC): ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, + new_computed_blocks: tuple[ + list[KVCacheBlock], ...], + num_encoder_tokens: int) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -54,14 +55,22 @@ class KVCacheCoordinator(ABC): tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The number of blocks. """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, []) + else: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) return num_blocks_to_allocate def save_new_computed_blocks( @@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC): manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, + request_id: str, + num_tokens: int, + num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC): request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The new allocated blocks. """ return tuple( - manager.allocate_new_blocks(request_id, num_tokens) + manager.allocate_new_blocks( + request_id, num_encoder_tokens if isinstance( + manager, CrossAttentionManager) else num_tokens) for manager in self.single_type_managers) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: @@ -103,7 +119,8 @@ class KVCacheCoordinator(ABC): Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_computed_tokens: The total number of tokens + that need to be cached (including tokens that are already cached). """ for manager in self.single_type_managers: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bfaa7ab08f5cf20b5b5986c5bce18d0afe56d58d..87a11fe58a0485418627bbbbd94847f25d84fff9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -37,15 +37,35 @@ class KVCacheBlocks: tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks))) - def get_block_ids(self) -> tuple[list[int], ...]: + @overload + def get_block_ids( + self, + allow_none: Literal[False] = False, + ) -> tuple[list[int], ...]: + ... + + @overload + def get_block_ids( + self, + allow_none: Literal[True] = True, + ) -> Optional[tuple[list[int], ...]]: + ... + + def get_block_ids( + self, + allow_none: bool = False, + ) -> Optional[tuple[list[int], ...]]: """ Converts the KVCacheBlocks instance to block_ids. - + Returns: - tuple[list[int], ...]: A tuple of lists where - * the outer tuple corresponds to KV cache groups - * each inner list contains the block_ids of the blocks in that group + tuple[list[int], ...]: A tuple of lists where: + - the outer tuple corresponds to KV cache groups + - each inner list contains the block_ids of the blocks in that + group """ + if allow_none and all(len(group) == 0 for group in self.blocks): + return None return tuple([blk.block_id for blk in group] for group in self.blocks) def get_unhashed_block_ids(self) -> list[int]: @@ -168,6 +188,7 @@ class KVCacheManager: new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -234,6 +255,7 @@ class KVCacheManager: request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -254,7 +276,7 @@ class KVCacheManager: new_computed_block_list) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot) + request.request_id, num_tokens_need_slot, num_encoder_tokens) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -273,7 +295,7 @@ class KVCacheManager: def free(self, request: Request) -> None: """Free the blocks allocated for the request. - We free the blocks in reverse order so that he tail blocks are evicted + We free the blocks in reverse order so that the tail blocks are evicted first when caching is enabled. Args: @@ -348,10 +370,13 @@ class KVCacheManager: """ return self.block_pool.take_events() + def get_blocks(self, request_id: str) -> KVCacheBlocks: + """Get the blocks of a request.""" + return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" - return KVCacheBlocks( - self.coordinator.get_blocks(request_id)).get_block_ids() + return self.get_blocks(request_id).get_block_ids() def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6a62c55fb2d5fbfe6f936ee7ff3d0828b5ada2d7..590baa6208d078eb051d1c3b70b56ce0843f5d06 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -527,6 +527,7 @@ def hash_block_tokens( hash values for the same block contents. Args: + hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None if this is the first block. curr_block_token_ids: A list of token ids in the current diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index dd5052a3480b79de11b9f82263ca2f045daa9b3f..5b1de3a66ceb460ad8926fd7657024a0f63cb14d 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats - from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -61,6 +61,14 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def update_draft_token_ids( + self, + draft_token_ids: "DraftTokenIds", + ) -> None: + """Update the draft token ids for the scheduled requests.""" + raise NotImplementedError + @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the scheduler's internal queue. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index fac07f97195bd616b7302c264caf1d23caa8cb71..b5cd6c5c8af51a4e7765461f709f99701c02bf19 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -91,7 +91,7 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[tuple[list[int], ...]] + new_block_ids: list[Optional[tuple[list[int], ...]]] num_computed_tokens: list[int] @property @@ -143,9 +143,9 @@ class SchedulerOutput: # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of (req_id, encoder_input_index) tuples. - # Used to free the encoder cache. - free_encoder_input_ids: list[tuple[str, int]] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. + free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0cab63740d15f238736d9ffc1135a8b0254771fe..a9038f9bc32897380cab42e58695eaa5802cb83c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,18 +19,18 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.request_queue import (SchedulingPolicy, create_request_queue) -from vllm.v1.core.sched.utils import check_stop +from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager @@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface): self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned @@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface): assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) @@ -141,7 +145,7 @@ class Scheduler(SchedulerInterface): cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - self.speculative_config = speculative_config + # self.speculative_config = speculative_config self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 @@ -179,20 +183,12 @@ class Scheduler(SchedulerInterface): scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to the running request index. - # This will helps us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - - req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} - encoder_budget = self.max_num_encoder_input_tokens + encoder_compute_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: dict[str, list[int]] = {} @@ -221,12 +217,13 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -258,10 +255,13 @@ class Scheduler(SchedulerInterface): key=lambda r: (r.priority, r.arrival_time), ) self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) else: preempted_req = self.running.pop() self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.log_stats: @@ -284,14 +284,7 @@ class Scheduler(SchedulerInterface): # Schedule the request. scheduled_running_reqs.append(request) - if request.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. - # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. - structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = ( - new_blocks.get_block_ids()) + req_to_new_blocks[request.request_id] = new_blocks num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -314,7 +307,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -398,7 +391,7 @@ class Scheduler(SchedulerInterface): num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -429,10 +422,10 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget + new_encoder_compute_budget ) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -446,6 +439,22 @@ class Scheduler(SchedulerInterface): == 0 else self.num_lookahead_tokens) + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + assert ("whisper" + in self.vllm_config.model_config.model.lower()), ( + "Whisper is the only supported " + "encoder-decoder model.") + num_encoder_tokens = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len( + self.vllm_config.model_config) + else: + num_encoder_tokens = 0 + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, @@ -453,6 +462,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, ) if new_blocks is None: @@ -480,9 +490,6 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue - if request.use_structured_output: - structured_output_request_ids[request.request_id] = ( - req_index) req_index += 1 self.running.append(request) if self.log_stats: @@ -498,8 +505,8 @@ class Scheduler(SchedulerInterface): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -514,7 +521,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -541,15 +548,10 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) - grammar_bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id]) + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -557,8 +559,11 @@ class Scheduler(SchedulerInterface): scheduled_resumed_reqs, num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids, + req_to_new_blocks, ) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(self.running, + scheduled_spec_decode_tokens)) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -572,7 +577,8 @@ class Scheduler(SchedulerInterface): # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -585,7 +591,19 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events if events: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) @@ -1052,11 +1070,11 @@ class Scheduler(SchedulerInterface): resumed_reqs: list[Request], num_scheduled_tokens: dict[str, int], spec_decode_tokens: dict[str, list[int]], - req_to_new_block_ids: dict[str, tuple[list[int], ...]], + req_to_new_blocks: dict[str, KVCacheBlocks], ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] num_computed_tokens: list[int] = [] use_connector = self.connector is not None @@ -1079,7 +1097,8 @@ class Scheduler(SchedulerInterface): # out of bounds errors. TODO: Remove this once the KVConnector # is updated to handle token IDs properly. new_token_ids.append([]) - new_block_ids.append(req_to_new_block_ids[req_id]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) num_computed_tokens.append(req.num_computed_tokens) # Because resumed_reqs is usually empty, it is more efficient to do # in-place appending so that we don't need to allocate a new list. @@ -1099,7 +1118,7 @@ class Scheduler(SchedulerInterface): request: Request, num_computed_tokens: int, num_new_tokens: int, - encoder_budget: int, + encoder_compute_budget: int, ) -> tuple[list[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, @@ -1121,11 +1140,17 @@ class Scheduler(SchedulerInterface): blocks and externally cached blocks (via KVConnector). """ if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None assert len(mm_positions) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1136,13 +1161,34 @@ class Scheduler(SchedulerInterface): if start_pos >= num_computed_tokens + num_new_tokens: # The encoder input is not needed in this step. break - if start_pos + num_encoder_tokens <= num_computed_tokens: + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: # The encoder input is already computed and stored # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): - # The encoder input is already computed and cached. + # The same encoder input has already been scheduled in the current + # step. + if request.mm_hashes[i] in mm_hashes_to_schedule: + continue + + if self.encoder_cache_manager.check_and_update_cache(request, i): + # The encoder input is already computed and cached from a + # previous step. continue # If no encoder input chunking is allowed, we do not want to @@ -1155,8 +1201,9 @@ class Scheduler(SchedulerInterface): num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) - or num_encoder_tokens > encoder_budget): + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -1173,9 +1220,46 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break - encoder_budget -= num_encoder_tokens + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_hashes[i]) encoder_inputs_to_schedule.append(i) - return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask def update_from_output( self, @@ -1183,7 +1267,6 @@ class Scheduler(SchedulerInterface): model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -1268,20 +1351,9 @@ class Scheduler(SchedulerInterface): request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) - # spec_token_ids comes from the model runner output if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] - # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ @@ -1308,9 +1380,7 @@ class Scheduler(SchedulerInterface): # Remove the stopped requests from the running and waiting queues. if stopped_running_reqs: - self.running = [ - req for req in self.running if req not in stopped_running_reqs - ] + self.running = remove_all(self.running, stopped_running_reqs) if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) @@ -1340,10 +1410,13 @@ class Scheduler(SchedulerInterface): finished_requests=finished_set) finished_req_ids.clear() - if engine_core_outputs: + if (stats := self.make_stats(spec_decoding_stats)) is not None: # Return stats to only one of the front-ends. - next(iter(engine_core_outputs.values())).scheduler_stats = ( - self.make_stats(spec_decoding_stats)) + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats return engine_core_outputs @@ -1386,6 +1459,30 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free_encoder_input( request, input_id) + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. + request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) @@ -1412,7 +1509,7 @@ class Scheduler(SchedulerInterface): else: request_ids = set(request_ids) - running_requests_to_remove = [] + running_requests_to_remove = set() waiting_requests_to_remove = [] valid_requests = [] @@ -1425,13 +1522,13 @@ class Scheduler(SchedulerInterface): valid_requests.append(request) if request.status == RequestStatus.RUNNING: - running_requests_to_remove.append(request) + running_requests_to_remove.add(request) else: waiting_requests_to_remove.append(request) # Remove all requests from queues at once for better efficiency - for request in running_requests_to_remove: - self.running.remove(request) + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) if waiting_requests_to_remove: self.waiting.remove_requests(waiting_requests_to_remove) @@ -1546,7 +1643,7 @@ class Scheduler(SchedulerInterface): # Now that the blocks are ready, actually cache them. (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less then one block. + # Handle the case where num request tokens less than one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens) if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 42ec95091f962739aad0b04b5a32e70ce11847e3..42d3e5c68b4c81c1dc020311d210591aac5979b5 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import Optional import torch @@ -7,6 +8,38 @@ import torch from vllm.v1.request import Request, RequestStatus +def remove_all(lst: list, items_to_remove: set) -> list: + """Remove all items from a list that are in the items_to_remove set. + + This method optimizes for the common case of removing a single item, + falling back to list comprehension for multiple items. + + Args: + lst: The list to remove items from + items_to_remove: Set of items to remove + + Returns: + Either the modified original list (for single item removal) or + a new list (for multiple item removal). Callers should use the + returned value. + + Note: + For single item removal, this modifies the original list in-place + and returns it. For multiple items, it creates and returns a new list. + """ + if not items_to_remove: + return lst + + if len(items_to_remove) == 1: + # Fast path for single item removal (most common case) + item = next(iter(items_to_remove)) + with contextlib.suppress(ValueError): + lst.remove(item) + return lst + # For multiple items, use list comprehension + return [item for item in lst if item not in items_to_remove] + + def check_stop(request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None) -> bool: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 82e0292522b9ace5d0123368e4827ac775f6f771..f6affb3dab66ff7085c894092a06b06955b1d96f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -8,8 +8,9 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + CrossAttentionSpec, FullAttentionSpec, + KVCacheSpec, MambaSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -46,7 +47,7 @@ class SingleTypeKVCacheManager(ABC): # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. + # data for preempted ones. self.num_cached_block: dict[str, int] = {} self.kv_cache_group_id = kv_cache_group_id @@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager): return new_blocks +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so `new_computed_blocks` should always be empty. + assert len(new_computed_blocks) == 0 + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. + raise ValueError("Should not be called as prefix caching is disabled.") + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. Encoder states are computed once per request, not incrementally + # 3. No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + raise NotImplementedError( + "CrossAttentionManager does not support caching") + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped + pass + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 02e65820b7c00996d53c0c0ad8a4d977a0f0b964..d2db7dcb3f0910a5369de1f16a4156f7c0a66627 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -11,7 +11,8 @@ logger = init_logger(__name__) class CudagraphDispatcher: """ - Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs. + Runtime cudagraph dispatcher to dispatch keys for multiple set of + cudagraphs. The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one for FULL cudagraph runtime mode. The keys are initialized depending on @@ -21,7 +22,7 @@ class CudagraphDispatcher: At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) - based on the input key. After dispatching (commuicate via forward context), + based on the input key. After dispatching (communicate via forward context), the cudagraph wrappers will trust the dispatch key to do either capturing or replaying (if mode matched), or pass through to the underlying runnable without cudagraph (if mode no match or mode is NONE). diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index f7ec982db41b40319a99ffa023c9291f59f3e84f..5d8959a3cd3feb4565b28b4f6eab0f1bde0f0f08 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,14 +3,13 @@ import enum import time -from collections.abc import Sequence from typing import Any, Optional, Union import msgspec import torch from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats @@ -48,9 +47,7 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] - mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]] - mm_hashes: Optional[list[str]] - mm_placeholders: Optional[list[PlaceholderRange]] + mm_features: Optional[list[MultiModalFeatureSpec]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] eos_token_id: Optional[int] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 664fec31a4da5fb7d921cee944600ff62c99e6d7..2a9fa1fd9172c023ac824277199d55ab4965187f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import os +import socket import time from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy from typing import Any, Optional, Union import numpy as np +import torch import vllm.envs as envs from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient +from vllm.entrypoints.utils import _validate_truncation_size from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.inputs import PromptType from vllm.inputs.preprocess import InputPreprocessor @@ -144,6 +148,26 @@ class AsyncLLM(EngineClient): except RuntimeError: pass + if envs.VLLM_TORCH_PROFILER_DIR: + logger.info( + "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 + envs.VLLM_TORCH_PROFILER_DIR) + worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + envs.VLLM_TORCH_PROFILER_DIR, + worker_name=worker_name, + use_gzip=True)) + else: + logger.info( + "Torch profiler disabled. AsyncLLM CPU traces will not be collected." # noqa: E501 + ) + self.profiler = None + @classmethod @deprecate_kwargs( "disable_log_requests", @@ -312,12 +336,28 @@ class AsyncLLM(EngineClient): returning the RequestOutput back to the caller. """ + if (self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs): + raise ValueError( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, please disable it when the requests need " + "prompt logprobs") + try: # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. self._run_output_handler() + tokenization_kwargs: dict[str, Any] = {} + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, @@ -325,6 +365,7 @@ class AsyncLLM(EngineClient): lora_request=lora_request, trace_headers=trace_headers, priority=priority, + tokenization_kwargs=tokenization_kwargs, data_parallel_rank=data_parallel_rank, ) @@ -451,6 +492,7 @@ class AsyncLLM(EngineClient): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + truncate_prompt_tokens: Optional[int] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ @@ -473,6 +515,14 @@ class AsyncLLM(EngineClient): # to handle startup failure gracefully in the OpenAI server. self._run_output_handler() + if tokenization_kwargs is None: + tokenization_kwargs = dict[str, Any]() + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, @@ -562,14 +612,19 @@ class AsyncLLM(EngineClient): raise self.dead_error async def start_profile(self) -> None: - await self.engine_core.profile_async(True) + coros = [self.engine_core.profile_async(True)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.start)) + await asyncio.gather(*coros) async def stop_profile(self) -> None: - await self.engine_core.profile_async(False) + coros = [self.engine_core.profile_async(False)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.stop)) + await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index aaa1800cb3f416bec493c08637d5cfb633ceb286..930b308c75216ca26e07d915ebe82b3e223432bc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -24,6 +24,7 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -40,8 +41,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, + get_device_indices) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -128,21 +129,23 @@ class EngineCore: > 1, log_stats=self.log_stats, ) + self.use_spec_decode = vllm_config.speculative_config is not None - self.mm_input_cache_server = MultiModalInputCacheServer( - vllm_config.model_config, MULTIMODAL_REGISTRY) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = receiver_cache_from_config( + vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], - SchedulerOutput]]] = None + self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput], + SchedulerOutput]]] = None if self.batch_queue_size > 1: logger.info("Batch queue is enabled with size %d", self.batch_queue_size) - self.batch_queue = queue.Queue(self.batch_queue_size) + self.batch_queue = deque(maxlen=self.batch_queue_size) self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None @@ -297,6 +300,13 @@ class EngineCore: return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) + def post_step(self, model_executed: bool) -> None: + if self.use_spec_decode and model_executed: + # Take the draft token ids. + draft_token_ids = self.model_executor.take_draft_token_ids() + if draft_token_ids is not None: + self.scheduler.update_draft_token_ids(draft_token_ids) + def step_with_batch_queue( self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. @@ -312,41 +322,43 @@ class EngineCore: batch in the job queue is finished. 3. Update the scheduler from the output. """ - assert self.batch_queue is not None + batch_queue = self.batch_queue + assert batch_queue is not None - engine_core_outputs = None - scheduler_output = None # Try to schedule a new batch if the batch queue is not full, but # the scheduler may return an empty batch if all requests are scheduled. # Note that this is not blocking. - if not self.batch_queue.full(): - scheduler_output = self.scheduler.schedule() - if scheduler_output.total_num_scheduled_tokens > 0: - future = self.model_executor.execute_model(scheduler_output) - self.batch_queue.put_nowait( - (future, scheduler_output)) # type: ignore - - scheduled_batch = (scheduler_output is not None - and scheduler_output.total_num_scheduled_tokens > 0) - - # If no more requests can be scheduled and the job queue is not empty, - # block until the first batch in the job queue is finished. - # TODO(comaniac): Ideally we should peek the first batch in the - # job queue to check if it's finished before scheduling a new batch, - # but peeking the first element in a queue is not thread-safe, - # so we need more work. - if not scheduled_batch and not self.batch_queue.empty(): - future, scheduler_output = self.batch_queue.get_nowait() + assert len(batch_queue) < self.batch_queue_size - # Blocking until the first result is available. - model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output) + model_executed = False + if self.scheduler.has_requests(): + scheduler_output = self.scheduler.schedule() + future = self.model_executor.execute_model(scheduler_output) + batch_queue.appendleft( + (future, scheduler_output)) # type: ignore[arg-type] + + model_executed = scheduler_output.total_num_scheduled_tokens > 0 + if model_executed and len(batch_queue) < self.batch_queue_size \ + and not batch_queue[-1][0].done(): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. + return None, True + + elif not batch_queue: + # Queue is empty. We should not reach here since this method should + # only be called when the scheduler contains requests or the queue + # is non-empty. + return None, False + + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + model_output = self.execute_model_with_error_logging( + lambda _: future.result(), scheduler_output) - self.batch_queue.task_done() - engine_core_outputs = (self.scheduler.update_from_output( - scheduler_output, model_output)) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output) - return engine_core_outputs, scheduled_batch + return engine_core_outputs, model_executed def shutdown(self): self.structured_output_manager.clear_backend() @@ -365,7 +377,8 @@ class EngineCore: logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -380,7 +393,7 @@ class EngineCore: return self.model_executor.is_sleeping def execute_dummy_batch(self): - self.model_executor.collective_rpc("execute_dummy_batch") + self.model_executor.execute_dummy_batch() def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) @@ -426,14 +439,13 @@ class EngineCore: This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ - if request.mm_hashes is not None: - assert request.mm_kwargs is not None - - # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, - # and will only accessed in the input processing thread afterwards. - request.mm_kwargs = self.mm_input_cache_server.get_and_update( - request.mm_kwargs, request.mm_hashes) + # Note on thread safety: no race condition. + # `mm_receiver_cache` is reset at the end of LLMEngine init, + # and will only accessed in the input processing thread afterwards. + if self.mm_receiver_cache is not None and request.mm_features: + request.mm_features = ( + self.mm_receiver_cache.get_and_update_features( + request.mm_features)) req = Request.from_engine_core_request(request, self.request_block_hasher) @@ -726,7 +738,8 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not self.scheduler.has_requests(): + while not self.engines_running and not self.scheduler.has_requests() \ + and not self.batch_queue: if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -749,6 +762,8 @@ class EngineCoreProc(EngineCore): # Put EngineCoreOutputs into the output queue. for output in (outputs.items() if outputs else ()): self.output_queue.put_nowait(output) + # Post-step hook. + self.post_step(model_executed) return model_executed @@ -1159,22 +1174,30 @@ class DPEngineCoreActor(DPEngineCoreProc): # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices(vllm_config, local_dp_rank, + device_control_env_var) + + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int, + device_control_env_var: str): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices(device_control_env_var, local_dp_rank, + world_size) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 079dd9a7d38d17445624631332c5ec0868d6372b..65f7abc97110c61979e7c1ab655425e1a0266b57 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1190,21 +1190,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient): await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - async def _send_reconfig_message( - self, reconfig_request: ReconfigureDistributedRequest, - engine: EngineIdentity) -> asyncio.Future: - """Send reconfiguration message and return the result future without - waiting for completion.""" - call_id = uuid.uuid1().int >> 64 - future = asyncio.get_running_loop().create_future() - self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, "reinitialize_distributed", - (reconfig_request, )))) - await self._send_input_message(message, engine, reconfig_request) - self._ensure_output_queue_task() - return future - async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" cur_data_parallel_size = len(self.core_engines) @@ -1214,7 +1199,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): f"different from cur_data_parallel_size {cur_data_parallel_size}") assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", ("Only ray DP backend supports scaling elastic EP") + "ray", "Only ray DP backend supports scaling elastic EP" scale_up = new_data_parallel_size > cur_data_parallel_size @@ -1246,9 +1231,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): data_parallel_master_ip, new_data_parallel_master_port=self.vllm_config.parallel_config. data_parallel_master_port) - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1318,9 +1304,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5a00a930951cc149435c13c1c5c866de48f8bb50..7130f666ef19fb5509aa7e41eafc38b147b49d3d 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -271,8 +271,7 @@ class LLMEngine: self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index aa7dc62fd4acbaea3bc7c68eb00a48557746048d..0000000000000000000000000000000000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional - -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import MultiModalKwargsItem -from vllm.utils import is_list_of - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -# The idea of multimodal input caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- P0: -# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of -# each input multi-modal item (e.g. image), -# - BaseMultiModalProcessor processes the input items into `mm_kwargs`, -# which are MultiModalKwargsItem instances that each correspond to an -# input multi-modal item. -# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding -# `mm_hash` for each item. It stores the `mm_hash` as keys and the size -# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking -# up additional memory in P0. -# - The `mm_hash` is always sent to P1. -# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached -# in MultiModalInputCacheServer. -# -# -- P1: -# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0), -# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`. -# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0), -# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`. -# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to -# the engine for model execution. -# -# Both Client and Server must perform cache update and eviction based on the -# same item size. This ensures that the keys of MultiModalInputCacheClient -# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 -# whether a key is cached in MultiModalInputCacheServer by querying -# MultiModalInputCacheClient without having to communicate with P1. - - -class MultiModalInputCacheClient: - """Used by P0 to check whether multi-modal kwargs are cached in P1.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalCacheItemMetadata, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[Optional[MultiModalKwargsItem]]: - if not self.enabled: - return list(mm_kwargs) - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[Optional[MultiModalKwargsItem]]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - out_mm_items.append(None) - else: - self.mm_cache[mm_hash] = \ - MultiModalCacheItemMetadata.wraps(mm_item) - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() - - -class MultiModalInputCacheServer: - """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalKwargsItem, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[Optional[MultiModalKwargsItem]], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - mm_kwargs_lst = list(mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem) - return mm_kwargs_lst - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if mm_item is None: - out_mm_items.append(self.mm_cache[mm_hash]) - else: - self.mm_cache[mm_hash] = mm_item - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index c6a23cdbf65aea2a814f46b4ad520fa4d3271745..1aa117ded4ed8ea787d4b927ec4f34e6ca3825e0 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -11,17 +11,18 @@ from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.cache import processor_cache_from_config +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from vllm.utils import is_list_of from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer) from vllm.v1.structured_output.backend_outlines import ( validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( @@ -46,16 +47,17 @@ class Processor: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - self.mm_input_cache_client = MultiModalInputCacheClient( - self.model_config, mm_registry) + self.mm_registry = mm_registry + self.mm_processor_cache = processor_cache_from_config( + vllm_config, mm_registry) - @property - def mm_registry(self): - return self.input_preprocessor.mm_registry + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) def _validate_logprobs( self, @@ -148,6 +150,49 @@ class Processor: self._validate_sampling_params(params, lora_request) self._validate_supported_sampling_params(params) + def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: + """ + Validate that user-provided multi_modal_uuids align with + multi_modal_data in the incoming request prompt(s). + Only checks lengths; `None` entries are allowed and will be + auto-hashed downstream. + """ + + def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: + if not isinstance(single_prompt, dict): + return + mm_data = single_prompt.get("multi_modal_data") + mm_uuids = single_prompt.get("multi_modal_uuids") + if not mm_data or not mm_uuids: + return + + for modality, items in mm_data.items(): + if modality in mm_uuids: + data_len = len(items) if isinstance(items, list) else 1 + uuid_len = len(mm_uuids[modality]) if isinstance( + mm_uuids[modality], list) else 1 + if uuid_len != data_len: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' " + "must have same length as data: got " + f"{uuid_len} uuids vs " + f"{data_len} items.") + else: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' must " + "be provided if multi_modal_data is provided.") + + # Handle explicit encoder/decoder prompts or singleton prompt + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + enc = prompt.get("encoder_prompt") + dec = prompt.get("decoder_prompt") + if enc is not None: + _validate_single_prompt(enc) + if dec is not None: + _validate_single_prompt(dec) + else: + _validate_single_prompt(prompt) # type: ignore[arg-type] + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -201,6 +246,9 @@ class Processor: elif engine_level_backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) + elif engine_level_backend == "lm-format-enforcer": + # lm format enforcer backend + validate_structured_output_request_lm_format_enforcer(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. @@ -220,6 +268,41 @@ class Processor: # Remember that this backend was set automatically params.guided_decoding.backend_was_auto = True + def _maybe_build_mm_hash_overrides( + self, + request_id: str, + prompt: PromptType, + ) -> Optional[dict[str, list[str]]]: + """Build per-item multimodal hash overrides when enabled. In this case, + multimodal data items are identified by their request id, modality and + index rather than their content. + + Returns a dictionary of modality -> list[str] of overrides, or None if + disabled or no multimodal data is present. + """ + + def _extract_mm_data(p: PromptType): + if isinstance(p, dict) and "encoder_prompt" in p: + enc = p.get("encoder_prompt") + if isinstance(enc, dict): + return enc.get("multi_modal_data") + return None + if isinstance(p, dict): + return p.get("multi_modal_data") + return None + + mm_data = _extract_mm_data(prompt) + if not mm_data: + return None + + overrides: dict[str, list[str]] = {} + for modality, data in mm_data.items(): + n = len(data) if isinstance(data, list) else 1 + overrides[modality] = [ + f"{request_id}-{modality}-{i}" for i in range(n) + ] + return overrides + def process_inputs( self, request_id: str, @@ -249,17 +332,37 @@ class Processor: if arrival_time is None: arrival_time = time.time() + # Optionally generate multimodal hash overrides to avoid hashing + # multimodal data items by their content as their identifiers. + + # NOTE: when users explicitly turn off BOTH prefix caching and input + # processing caching, no multimodal features or embeddings will be + # reused across requests, therefore identifying multimodal data items + # by their content is no longer necessary, and we create uuids with + # request id-modality-index as multimodal hash overrides. + if (self.model_config.multimodal_config and + self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching): + mm_hash_overrides = self._maybe_build_mm_hash_overrides( + request_id, prompt) + else: + # Otherwise, use user-provided uuids as multimodal hash overrides + # if provided. + self._validate_multi_modal_uuids(prompt) + if isinstance(prompt, dict): + mm_hash_overrides = prompt.get("multi_modal_uuids") + else: + mm_hash_overrides = None + # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. - return_mm_hashes = (self.model_config.processor_return_mm_hashes - or bool(self.cache_config.enable_prefix_caching)) processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -267,6 +370,7 @@ class Processor: params=params, processed_inputs=processed_inputs, ) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) self._validate_model_inputs(processed_inputs, lora_request) @@ -296,47 +400,31 @@ class Processor: pooling_params = params.clone() # Multimodal related. - sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None - sorted_mm_positions: Optional[list[PlaceholderRange]] = None - sorted_mm_hashes: Optional[list[str]] = None + mm_features: Optional[list[MultiModalFeatureSpec]] = None + if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs.get("mm_hashes") + decoder_mm_hashes = decoder_inputs["mm_hashes"] # Merge and flatten multimodal placeholders, hashes and inputs # from dictionaries to lists, and sort them by each item's position # in the input sequence. sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - orig_sorted_mm_inputs = [ - decoder_mm_inputs.get_item(modality, idx) - for modality, idx in sorted_mm_idxs - ] - sorted_mm_positions = [ - decoder_mm_positions[modality][idx] - for modality, idx in sorted_mm_idxs - ] - sorted_mm_hashes = None if decoder_mm_hashes is None else [ - decoder_mm_hashes[modality][idx] - for modality, idx in sorted_mm_idxs - ] - - if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - orig_sorted_mm_inputs, - sorted_mm_hashes, - ) - else: - assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem) - sorted_mm_inputs = orig_sorted_mm_inputs + mm_features = [] + for modality, idx in sorted_mm_idxs: + mm_features.append( + MultiModalFeatureSpec( + data=decoder_mm_inputs[modality][idx], + modality=modality, + identifier=decoder_mm_hashes[modality][idx], + mm_position=decoder_mm_positions[modality][idx])) return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], - mm_kwargs=sorted_mm_inputs, - mm_hashes=sorted_mm_hashes, - mm_placeholders=sorted_mm_positions, + mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, eos_token_id=eos_token_id, @@ -382,7 +470,19 @@ class Processor: else: tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: + + # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while + # self.model_config.get_vocab_size() is the model’s vocab size. + # For Qwen3 models, the language model has extra tokens that do + # not exist in the tokenizer, and vice versa for multimodal + # placeholder tokens in some multimodal models. + # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501 + # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501 + + # Here we take the max of the two to determine if a token id is + # truly out-of-vocabulary. + if max_input_id > max(tokenizer.max_token_id, + self.model_config.get_vocab_size() - 1): raise ValueError( f"Token id {max_input_id} is out of vocabulary") @@ -397,7 +497,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( @@ -418,3 +518,6 @@ class Processor: # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def clear_cache(self) -> None: + self.input_preprocessor.clear_cache() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 770aa7d9dcc8a1d6571dc972336f7cf51b3f7f26..56ef8477d267ac03a1ed6ffec239448653564b1c 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -71,7 +71,7 @@ class EngineHandshakeMetadata: connect to. """ addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str]] + parallel_config: dict[str, Union[int, str, list[int]]] class CoreEngineProcManager: @@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value), )): + yield + + +def get_device_indices(device_control_env_var: str, local_dp_rank: int, + world_size: int): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. + """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) except IndexError as e: - raise Exception(f"Error setting {evar}: " + raise Exception(f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + f"\"{os.getenv(device_control_env_var)}\"") from e + return value class CoreEngineActorManager: @@ -254,6 +268,19 @@ class CoreEngineActorManager: dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices(device_evar, local_index, + world_size) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -798,6 +825,8 @@ def wait_for_engine_startup( parallel_config.data_parallel_master_ip, "data_parallel_master_port": parallel_config.data_parallel_master_port, + "_data_parallel_master_port_list": + parallel_config._data_parallel_master_port_list, "data_parallel_size": parallel_config.data_parallel_size, })) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49e1b638ac278bab8b0a3501105b0468..68408a0b8a3d58316dadd5908af631b2c9a80a58 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -13,8 +13,9 @@ from vllm.executor.uniproc_executor import ( # noqa ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) +from vllm.utils import resolve_obj_by_qualname from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput FailureCallback = Callable[[], None] @@ -50,6 +51,13 @@ class Executor(ExecutorBase): # TODO: make v1 scheduling deterministic # to support external launcher executor_class = ExecutorWithExternalLauncher + elif isinstance(distributed_executor_backend, str): + executor_class = resolve_obj_by_qualname( + distributed_executor_backend) + if not issubclass(executor_class, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {executor_class}.") else: raise ValueError("Unknown distributed executor backend: " f"{distributed_executor_backend}") @@ -73,12 +81,10 @@ class Executor(ExecutorBase): pass def determine_available_memory(self) -> list[int]: # in bytes - output = self.collective_rpc("determine_available_memory") - return output + return self.collective_rpc("determine_available_memory") def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: - output = self.collective_rpc("get_kv_cache_spec") - return output + return self.collective_rpc("get_kv_cache_spec") def execute_model( self, @@ -88,6 +94,13 @@ class Executor(ExecutorBase): args=(scheduler_output, )) return output[0] + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch") + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + output = self.collective_rpc("take_draft_token_ids") + return output[0] + @property def max_concurrent_batches(self) -> int: return 1 diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0db3bcd7fb408bea7eb64bd2ca2f7ff09aa55c21..12e79ff165f4e5ac56572acf4fae09a0621c3df2 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -191,6 +191,16 @@ class MultiprocExecutor(Executor): outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch", + unique_reply_rank=self.output_rank) + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + # OPTIMIZATION: Get output only from a single worker (output_rank) + outputs = self.collective_rpc("take_draft_token_ids", + unique_reply_rank=self.output_rank) + return outputs[0] + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, @@ -236,12 +246,17 @@ class MultiprocExecutor(Executor): dequeue_timeout = None if deadline is None else ( deadline - time.monotonic()) - if non_block: + if self.io_thread_pool is not None: + # We must consume worker_response_mq from a single thread. result = self.io_thread_pool.submit( # type: ignore get_response, w, dequeue_timeout, self.shutdown_event) - else: + if not non_block: + result = result.result() + elif not non_block: result = get_response(w, dequeue_timeout) - + else: + raise RuntimeError("non_block can only be used when" + " max_concurrent_batches > 1") responses.append(result) return responses diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index c05ad1966d6112ff71b1644e4abe8422116585a2..8394ae788ab016142456ca93a185bc3b829873f7 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -8,6 +8,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -64,7 +65,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: """Execute the model on the Ray workers. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 429416afa2483ba281e1699f476eb6b8b76fcc50..a3e4d393e4d2054ee4768b40747abc990cb774f7 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,6 +11,7 @@ from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -203,6 +204,28 @@ class MambaSpec(KVCacheSpec): return self.page_size_bytes +@dataclass(frozen=True) +class EncoderOnlyAttentionSpec(AttentionSpec): + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # Encoder-only layers do not need KV cache + return 0 + + +@dataclass(frozen=True) +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Get encoder length (e.g., 1500 for Whisper). + max_encoder_len = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len(vllm_config.model_config) + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index 61ba5d66cb31a473e43d7cfd6a6de12e23fc4491..a43cf9ce255e6c00ca82cca0de12ba01a197929b 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -36,7 +36,7 @@ def setup_multiprocess_prometheus(): "and vLLM will properly handle cleanup.") -def get_prometheus_registry(): +def get_prometheus_registry() -> CollectorRegistry: """Get the appropriate prometheus registry based on multiprocessing configuration. diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7d7cd0c94dd0411ad777f2f1dfe720e008862dc1..f8d6b24702f3c15f3852cfcc9d3bbe0c8bb948c8 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -94,9 +94,6 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. sampled_token_ids: list[list[int]] - # num_reqs x num_spec_tokens - spec_token_ids: Optional[list[list[int]]] - # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] @@ -117,10 +114,18 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +@dataclass +class DraftTokenIds: + + # [num_reqs] + req_ids: list[str] + # num_reqs x num_draft_tokens + draft_token_ids: list[list[int]] + + EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 28af720d05fd16ee0eb58752d16482721409b42a..46506d272e90a507beea90b5629b38fee6a3f177 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -6,15 +6,40 @@ from typing import Optional import torch from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +@dataclass +class PoolingCursor: + index: list[int] + first_token_indices_gpu: torch.Tensor + last_token_indices_gpu: torch.Tensor + prompt_lens_cpu: torch.Tensor + num_scheduled_tokens_cpu: torch.Tensor + + def __getitem__(self, indices: slice): + return PoolingCursor( + index=self.index[indices], + first_token_indices_gpu=self.first_token_indices_gpu[indices], + last_token_indices_gpu=self.last_token_indices_gpu[indices], + prompt_lens_cpu=self.prompt_lens_cpu[indices], + num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], + ) + + def is_partial_prefill(self): + return not torch.all( + self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" - - prompt_lens: torch.Tensor + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] + pooling_cursor: Optional[PoolingCursor] = None def __getitem__(self, indices: slice): return PoolingMetadata( @@ -22,4 +47,31 @@ class PoolingMetadata: prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], + pooling_cursor=None + if self.pooling_cursor is None else self.pooling_cursor[indices], ) + + def build_pooling_cursor(self, num_scheduled_tokens: list[int], + device: torch.device): + self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, + self.prompt_lens, device) + + +def build_pooling_cursor(num_scheduled_tokens: list[int], + prompt_lens: torch.Tensor, device: torch.device): + assert len(prompt_lens) == len(num_scheduled_tokens) + + n_seq = len(num_scheduled_tokens) + index = list(range(n_seq)) + num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") + cumsum = torch.zeros(n_seq + 1, + dtype=torch.int64, + pin_memory=pin_memory, + device="cpu") + torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) + cumsum = cumsum.to(device, non_blocking=True) + return PoolingCursor(index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 4e99a9ccef46e09cc771d1a188c747864b4dce48..ad7477241ebbd7909c17e85e304706a7c3eb367b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -6,10 +6,9 @@ import time from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.utils import is_list_of from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) from vllm.v1.structured_output.request import StructuredOutputRequest @@ -26,14 +25,12 @@ class Request: self, request_id: str, prompt_token_ids: list[int], - multi_modal_kwargs: Optional[list[MultiModalKwargsItem]], - multi_modal_hashes: Optional[list[str]], - multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: Optional[SamplingParams], pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, arrival_time: Optional[float] = None, + mm_features: Optional[list[MultiModalFeatureSpec]] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, @@ -89,16 +86,14 @@ class Request: self.cache_salt: Optional[str] = cache_salt # Multi-modal related - self.mm_positions = multi_modal_placeholders or [] - self.mm_kwargs = multi_modal_kwargs or [] - self.mm_hashes: list[str] = multi_modal_hashes or [] - self.num_encoder_inputs = len(self.mm_kwargs) + self.mm_features = mm_features or [] + self.num_encoder_inputs = len(self.mm_features) self.has_encoder_inputs = self.num_encoder_inputs > 0 - - # Sanity check - assert len(self.mm_kwargs) == len(self.mm_positions) - if self.mm_hashes: - assert len(self.mm_kwargs) == len(self.mm_hashes) + # TODO(sfeng33): Remove these legacy fields after clearing out all + # references in scheduler and model runner + self.mm_positions = [f.mm_position for f in self.mm_features] + self.mm_kwargs = [f.data for f in self.mm_features] + self.mm_hashes = [f.identifier for f in self.mm_features] # Read-only views # Prevent directly appending to these lists since @@ -126,20 +121,11 @@ class Request: cls, request: EngineCoreRequest, block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] ) -> "Request": - if request.mm_kwargs is not None: - mm_kwargs_lst = list(request.mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), ( - "mm_kwargs was not updated in EngineCore.add_request") - else: - mm_kwargs_lst = None - return cls( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, - multi_modal_kwargs=mm_kwargs_lst, - multi_modal_hashes=request.mm_hashes, - multi_modal_placeholders=request.mm_placeholders, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 24387ab793906c4b3c68bc5d5dc499afa91c9830..60f9c0bdb63133ab9a76dbecc26e91758205a415 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional, TypeVar import torch +from vllm import SamplingParams from vllm.v1.sample.logits_processor.interface import (BatchUpdate, LogitsProcessor, MoveDirectionality) @@ -12,6 +13,8 @@ from vllm.v1.sample.logits_processor.interface import (BatchUpdate, if TYPE_CHECKING: from vllm.config import VllmConfig +T = TypeVar("T") + class MinPLogitsProcessor(LogitsProcessor): @@ -53,29 +56,37 @@ class MinPLogitsProcessor(LogitsProcessor): # Process added requests. for index, params, _, _ in batch_update.added: min_p = params.min_p - if self.min_p_cpu[index] != min_p: + min_p_before = self.min_p_cpu[index] + if min_p_before != min_p: needs_update = True self.min_p_cpu[index] = min_p - if min_p: - self.min_p_count += 1 + if min_p and not min_p_before: + self.min_p_count += 1 + elif not min_p and min_p_before: + self.min_p_count -= 1 if self.min_p_count: # Process removed requests. - needs_update |= bool(batch_update.removed) - for index in batch_update.removed: - if self.min_p_cpu[index]: - self.min_p_count -= 1 + if batch_update.removed: + needs_update = True + for index in batch_update.removed: + if self.min_p_cpu[index]: + self.min_p_cpu[index] = 0 + self.min_p_count -= 1 - # Process moved requests, unidirectional (a->b) and swap (a<->b) + # Process moved requests, unidirectional (a->b) and swap (a<->b). for adx, bdx, direct in batch_update.moved: - change = (min_p_a := - self.min_p_cpu[adx]) != (min_p_b := - self.min_p_cpu[bdx]) - needs_update |= change - if change: + min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx] + if min_p_a != min_p_b: + needs_update = True self.min_p_cpu[bdx] = min_p_a if direct == MoveDirectionality.SWAP: self.min_p_cpu[adx] = min_p_b + if direct == MoveDirectionality.UNIDIRECTIONAL: + if min_p_a: + self.min_p_cpu[adx] = 0 + if min_p_b: + self.min_p_count -= 1 # Update tensors if needed. size = batch_update.batch_size @@ -122,49 +133,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor): return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - needs_update: bool = False - # Process added requests. - for index, params, _, _ in batch_update.added: - if lb := params.logit_bias: - self.biases[index] = lb - needs_update = True - else: - # Drop biases metadata at batch index - if self.biases.pop(index, None) is not None: - # If a new request replaces an old request which - # specified biases, we should update processor tensors - needs_update = True - - if self.biases: - # Process removed requests. - for index in batch_update.removed: - if self.biases.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and swap (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.biases.pop(a_index, None)) is None: - if self.biases.pop(b_index, None) is not None: - needs_update = True - else: - self.biases[b_index] = a_entry - needs_update = True - else: - a_entry = self.biases.pop(a_index, None) - if (b_entry := self.biases.pop(b_index, None)) is not None: - self.biases[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.biases[b_index] = a_entry - needs_update = True + needs_update = process_dict_updates( + self.biases, batch_update, + lambda params, _, __: params.logit_bias or None) # Update tensors if needed. if needs_update: - reqs, tok_ids, biases = [], [], [] + reqs: list[int] = [] + tok_ids: list[int] = [] + biases: list[float] = [] for req, lb in self.biases.items(): reqs.extend([req] * len(lb)) tok_ids.extend(lb.keys()) @@ -208,52 +185,18 @@ class MinTokensLogitsProcessor(LogitsProcessor): of the argmax operation in greedy sampling.""" return False - def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = False - - if batch_update: - # Process added requests. - for index, params, _, output_tok_ids in batch_update.added: - if ((min_tokens := params.min_tokens) - and len(output_tok_ids) < min_tokens): - # Replace request metadata at batch index - self.min_toks[index] = (min_tokens, output_tok_ids, - params.all_stop_token_ids) - needs_update = True - else: - # Drop min_toks metadata at batch index - if self.min_toks.pop(index, None) is not None: - # If a new request replaces an old request which - # specified min_toks, we should update processor tensors - needs_update = True - - if self.min_toks: - # Process removed requests. - for index in batch_update.removed: - if self.min_toks.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and - # swapped (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.min_toks.pop(a_index, - None)) is None: - if self.min_toks.pop(b_index, None) is not None: - needs_update = True - else: - self.min_toks[b_index] = a_entry - needs_update = True - else: - a_entry = self.min_toks.pop(a_index, None) - if (b_entry := self.min_toks.pop(b_index, - None)) is not None: - self.min_toks[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.min_toks[b_index] = a_entry - needs_update = True + @staticmethod + def add_request( + params: SamplingParams, _: list[int], output_tok_ids: list[int] + ) -> Optional[tuple[int, Sequence[int], set[int]]]: + min_tokens = params.min_tokens + if not min_tokens or len(output_tok_ids) >= min_tokens: + return None + return min_tokens, output_tok_ids, params.all_stop_token_ids + def update_state(self, batch_update: Optional[BatchUpdate]): + needs_update = process_dict_updates(self.min_toks, batch_update, + self.add_request) if self.min_toks: # Check for any requests that have attained their min tokens. to_remove = tuple(index for index, (min_toks, out_tok_ids, @@ -287,3 +230,44 @@ class MinTokensLogitsProcessor(LogitsProcessor): # Inhibit EOS token for requests which have not reached min length logits[self.logits_slice] = -float("inf") return logits + + +def process_dict_updates( + req_entries: dict[int, T], batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]] +) -> bool: + """Utility function to update dict state for sparse LogitsProcessors.""" + + if not batch_update: + # Nothing to do. + return False + + updated = False + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + if (state := new_state(params, prompt_tok_ids, + output_tok_ids)) is not None: + req_entries[index] = state + updated = True + elif req_entries.pop(index, None) is not None: + updated = True + + if req_entries: + # Process removed requests. + for index in batch_update.removed: + if req_entries.pop(index, None): + updated = True + + # Process moved requests, unidirectional (a->b) and + # swapped (a<->b) + for a_index, b_index, direct in batch_update.moved: + a_entry = req_entries.pop(a_index, None) + b_entry = req_entries.pop(b_index, None) + if a_entry is not None: + req_entries[b_index] = a_entry + updated = True + if b_entry is not None: + updated = True + if direct == MoveDirectionality.SWAP: + req_entries[a_index] = b_entry + + return updated diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 12b4db24bff88b52c6991da55b8d6348315f4c91..683fc7c00dfb20d02a5c0634fbdb911226981fc2 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -44,10 +44,16 @@ class BatchUpdate: # Key assumption: the `output_tok_ids` list (which is an element of each # tuple in `added`) is a reference to the request's running output tokens # list; via this reference, the logits processors always see the latest - # list of generated output tokens + # list of generated output tokens. + # + # NOTE: + # * Added or moved requests may replace existing requests with the same + # index. + # * Operations should be processed in the following order: + # - removed, added, moved removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] added: Sequence[AddedRequest] + moved: Sequence[MovedRequest] class LogitsProcessor(ABC): @@ -59,6 +65,11 @@ class LogitsProcessor(ABC): @abstractmethod def apply(self, logits: torch.Tensor) -> torch.Tensor: + """Apply LogitsProcessor to batch logits tensor. + + The updated tensor must be returned but may be + modified in-place. + """ raise NotImplementedError @abstractmethod @@ -80,7 +91,7 @@ class LogitsProcessor(ABC): to each forward pass. Args: - batch_update is non-None iff there have been - changes to the batch makeup. + batch_update: Non-None iff there have been changes + to the batch makeup. """ raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 0f58b524969563b5866d88d24499026daa71cbce..31cece58c7db5f3cd124370a7cb9dfa1da4fca45 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -50,6 +50,10 @@ class BatchUpdateBuilder: self.added = added or [] self._is_removed_sorted = False + # Used to track changes in the pooling case + # where we don't populate the added list. + self.batch_changed = False + def _ensure_removed_sorted(self) -> None: """Sort removed request indices in descending order. @@ -80,6 +84,7 @@ class BatchUpdateBuilder: raise RuntimeError("Cannot register new removed request after" " self.removed has been read.") self._removed.append(index) + self.batch_changed = True def has_removed(self) -> bool: return bool(self._removed) @@ -98,9 +103,15 @@ class BatchUpdateBuilder: return self._removed.pop() return None - def _is_update(self) -> bool: - """True if there is a batch state change""" - return any((self._removed, self.moved, self.added)) + def reset(self) -> bool: + """Returns True if there were any changes to the batch.""" + self._is_removed_sorted = False + self._removed.clear() + self.moved.clear() + self.added.clear() + batch_changed = self.batch_changed + self.batch_changed = False + return batch_changed def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """Generate a logitsprocs batch update data structure and reset @@ -114,7 +125,8 @@ class BatchUpdateBuilder: """ # Reset removal-sorting logic self._is_removed_sorted = False - if not self._is_update(): + self.batch_changed = False + if not any((self._removed, self.moved, self.added)): # No update; short-circuit return None # Build batch state update diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e0434c8f3d7130f1988704dc1663cdce0313c609..7bd4a5a380ac06193af2b1b88a1c2874751e95ab 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -8,6 +8,7 @@ import torch.nn as nn from packaging import version from vllm import envs +from vllm.config import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__(self): + def __init__( + self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: super().__init__() - if current_platform.is_cuda(): + self.logprobs_mode = logprobs_mode + # flashinfer optimization does not apply if intermediate + # logprobs/logits after top_k/top_p need to be returned + if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, + LogprobsMode.PROCESSED_LOGPROBS + ) and current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): @@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module): "native implementation of top-p & top-k sampling. For the " "best performance, please install FlashInfer.") self.forward = self.forward_native - elif current_platform.is_tpu(): - self.forward = self.forward_tpu else: self.forward = self.forward_native + if current_platform.is_tpu(): + self.apply_top_k_top_p = apply_top_k_top_p_tpu + else: + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. """ - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + logits_to_return = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), logits_to_return def forward_cuda( self, @@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + if (k is None and p is None) or generators: + if generators: + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) + assert self.logprobs_mode not in ( + LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS + ), "FlashInfer does not support returning logits/logprobs" # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. - return flashinfer_sample(logits.contiguous(), k, p, generators) - - def forward_tpu( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> torch.Tensor: - logits = apply_top_k_top_p_tpu(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return flashinfer_sample(logits.contiguous(), k, p, generators), None def apply_top_k_top_p_tpu( diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index b2354c53302adb8e8df5b8cff281657b58992fbc..3d5e59addfcfaee615f36ca661263fc2db0f21eb 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -68,7 +68,7 @@ class RejectionSampler(nn.Module): different requests are flattened into a single tensor because this is the shape of the output logits. NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): + bonus_token_ids (torch.Tensor): A tensor containing bonus tokens. Shape is [batch_size, 1]. Bonus tokens are added to the end of the sequence if all proposed tokens are accepted. We generate the bonus tokens @@ -365,9 +365,14 @@ def generate_uniform_probs( A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ + # NOTE(woosuk): We deliberately use float64 instead of float32 here + # because when using float32, there's a non-negligible chance that + # uniform_prob is sampled to be exact 0.0 as reported in + # https://github.com/pytorch/pytorch/issues/16706. Using float64 + # mitigates the issue. uniform_probs = torch.rand( (num_tokens, ), - dtype=torch.float32, + dtype=torch.float64, device=device, ) start_idx = 0 @@ -593,17 +598,10 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - draft_token_id) - # Temporarily zero out the probability of the draft token. - # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - 0) prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, + mask=((vocab_offset < vocab_size) & + (vocab_offset != draft_token_id)), other=0) else: draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + @@ -623,9 +621,3 @@ def sample_recovered_tokens_kernel( other=float("-inf")) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - orig_prob) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82f51298f1b59ec7ecf59610285f55a081aeda0b..546531a91610fc077418baa68bad88ae29889bdf 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + import torch import torch.nn as nn @@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): - - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + """ + A layer that samples the next tokens from the model's outputs + with the following steps in order: + + 1. If logprobs are requested: + a) If `logprobs_mode` is `raw_logprobs`, compute logprobs + as the final logprobs to return. + b) If `logprobs_mode` is `raw_logits`, clone the logits + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. + 5. Apply logit processors which are not argmax-invariant, + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: + a) If not `all_random`, perform greedy sampling. If `all_greedy`, + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. + c) Apply logit processors which are argmax-invariant, by default + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. + f) If `all_random` or temperature >= epsilon (1e-5), return the + randomly sampled tokens and final logprobs if requested. Else, + return the greedily sampled tokens and logprobs if requested. + 8. Gather the logprobs of the top `max_num_logprobs` and sampled token + (if requested). Note that if the sampled token is within the top + `max_num_logprobs`, the logprob will be eventually merged in + `LogprobsProcessor` during output processing. Therefore, the + final output may contain either `max_num_logprobs + 1` or + `max_num_logprobs` logprobs. + 9. Return the final `SamplerOutput`. + """ + + def __init__(self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): super().__init__() - self.topk_topp_sampler = TopKTopPSampler() + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode @@ -34,13 +76,11 @@ class Sampler(nn.Module): # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). - # TODO(rob): provide option for logprobs post sampling. - # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: raw_logprobs = logits.clone() # Use float32 for the logits. @@ -51,21 +91,16 @@ class Sampler(nn.Module): logits = self.apply_bad_words(logits, sampling_metadata) # Apply logits processors which can impact greedy sampling - for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: logits = processor.apply(logits) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Get the process logprobs or logits. - if num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "processed_logits": - raw_logprobs = logits.clone() - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) + sampled, processed_logprobs = self.sample(logits, sampling_metadata) + if processed_logprobs is not None: + raw_logprobs = processed_logprobs # Convert sampled token ids to int64 (long) type to ensure compatibility # with subsequent operations that may use these values as indices. # This conversion is necessary because FlashInfer sampling operations @@ -105,7 +140,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Sample logits based on sampling metadata. The various logits processing functions called in this method @@ -119,7 +154,13 @@ class Sampler(nn.Module): else: greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: - return greedy_sampled + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + processed_logprobs = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None @@ -132,7 +173,7 @@ class Sampler(nn.Module): logits = processor.apply(logits) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, processed_logprobs = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -140,7 +181,7 @@ class Sampler(nn.Module): ) if greedy_sampled is None: - return random_sampled + return random_sampled, processed_logprobs sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, @@ -148,7 +189,7 @@ class Sampler(nn.Module): random_sampled, out=greedy_sampled, # Reuse tensor ) - return sampled + return sampled, processed_logprobs def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 2c9f4892bc247a77fdca88efb3193c55c9c4bd94..e84136e3a6d07b994734b72a6288b0c49ff98a5c 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -65,7 +65,7 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, _ = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -89,7 +89,7 @@ class Sampler(nn.Module): Gather logprobs for topk and sampled/prompt token. Args: - logits: (num tokens) x (vocab) tensor + logprobs: (num tokens) x (vocab) tensor num_logprobs: minimum number of logprobs to retain per token token_ids: prompt tokens (if prompt logprobs) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 2857d8ef4290967d421b4a9271e98e71526a3c0f..c8375d6f15517e75426a39ac52c232a6a88958d6 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -18,12 +18,15 @@ from msgspec import msgpack from vllm import envs from vllm.logger import init_logger +# yapf: disable from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, MultiModalFlatField, MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) +# yapf: enable from vllm.v1.engine import UtilityResult logger = init_logger(__name__) @@ -116,12 +119,11 @@ class MsgpackEncoder: if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) + if isinstance(obj, MultiModalKwargsItems): + return self._encode_mm_items(obj) + if isinstance(obj, MultiModalKwargs): - return [ - self._encode_mm_item(item) - for itemlist in obj._items_by_modality.values() - for item in itemlist - ] + return self._encode_mm_kwargs(obj) if isinstance(obj, UtilityResult): result = obj.result @@ -183,6 +185,12 @@ class MsgpackEncoder: dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data + def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]: + return { + modality: [self._encode_mm_item(item) for item in itemlist] + for modality, itemlist in items.items() + } + def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: return [self._encode_mm_field_elem(elem) for elem in item.values()] @@ -200,6 +208,12 @@ class MsgpackEncoder: self._encode_mm_field(elem.field), } + def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: + return { + modality: self._encode_nested_tensors(data) + for modality, data in kw.items() + } + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): return self._encode_tensor(nt) @@ -260,8 +274,10 @@ class MsgpackDecoder: return slice(*obj) if issubclass(t, MultiModalKwargsItem): return self._decode_mm_item(obj) + if issubclass(t, MultiModalKwargsItems): + return self._decode_mm_items(obj) if issubclass(t, MultiModalKwargs): - return MultiModalKwargs(self._decode_mm_items(obj)) + return self._decode_mm_kwargs(obj) if t is UtilityResult: return self._decode_utility_result(obj) return obj @@ -315,8 +331,11 @@ class MsgpackDecoder: # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) - def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]: - return [self._decode_mm_item(v) for v in obj] + def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: + return MultiModalKwargsItems({ + modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + }) def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: return MultiModalKwargsItem.from_elems( @@ -339,6 +358,12 @@ class MsgpackDecoder: obj["field"] = factory_meth(None, *field_args).field return MultiModalFieldElem(**obj) + def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: + return MultiModalKwargs({ + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + }) + def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cb567987d1301cd44ba72c922f15ebeca489e041..945f0ccc312a7e38f78ff8c6db3d69c9d07680a4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,8 @@ import ast from dataclasses import replace -from typing import Optional, Any +from importlib.util import find_spec +from typing import Optional, Protocol, Any import numpy as np @@ -23,8 +24,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -# from vllm.v1.attention.backends.rocm_aiter_fa import ( -# AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata @@ -39,6 +38,17 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: def __init__( @@ -107,6 +117,20 @@ class EagleProposer: dtype=self.dtype, device=device) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree self.tree_choices: list[tuple[int, @@ -178,7 +202,7 @@ class EagleProposer: for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -231,7 +255,7 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method == "deepseek_mtp": + if self.method in ("deepseek_mtp", "ernie_mtp"): last_hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states @@ -263,19 +287,7 @@ class EagleProposer: # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - - # On ROCm, both AiterFlashAttention and TritonAttention - # support multi-token eagle spec decode. - if current_platform.is_rocm(): - assert isinstance( - attn_metadata, - (TritonAttentionMetadata, # AiterFlashAttentionMetadata, - FlashAttentionMetadata)) - else: - # Currently, only FlashAttention supports multi-token eagle spec - # decode. This is because the code below makes assumptions about - # attn_metadata attributes available. - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -286,7 +298,7 @@ class EagleProposer: hidden_states = hidden_states[last_token_indices] if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -538,7 +550,7 @@ class EagleProposer: num_tokens, -1) if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_tokens) else: @@ -599,19 +611,19 @@ class EagleProposer: """ # E.g. # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -655,9 +667,9 @@ class EagleProposer: old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) @@ -674,6 +686,7 @@ class EagleProposer: num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, @@ -707,21 +720,19 @@ class EagleProposer: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ - and self.method != "deepseek_mtp" \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + and self.method != "deepseek_mtp" \ + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." - ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 309fd926aecd7b402881dda9dd1f2fc8bbf01472..3e90179e78d99d8d8e45547e2c6f80658e9d479e 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,12 +38,14 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks, None) # Get draft tokens and transpose the result + # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU + # synchronization. draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] return [list(row) for row in zip(*draft_tokens)] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 63604a335d9f08b97bebe7ef6e0644862200f7d3..57854cc1120413ed63a45053a6b9064064a2a479 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -108,6 +108,14 @@ class StructuredOutputManager: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "lm-format-enforcer": + from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 + LMFormatEnforcerBackend) + self.backend = LMFormatEnforcerBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") @@ -267,7 +275,7 @@ class StructuredOutputManager: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None # by default, we should always advance - # for cases that doesn't uses thinking mode. + # for cases that don't use thinking mode. if self.reasoner is not None: structured_req = request.structured_output_request @@ -276,7 +284,7 @@ class StructuredOutputManager: # Check if reasoning ends in *this* step if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advanced til + # Reasoning just ended, so we shouldn't advance til # next pass structured_req.reasoning_ended = True diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py new file mode 100644 index 0000000000000000000000000000000000000000..2279a1c8c8a00552770f76f2dac2a04c700e8734 --- /dev/null +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from functools import lru_cache +from typing import TYPE_CHECKING + +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import lmformatenforcer + import lmformatenforcer.integrations.vllm as lmfe_vllm +else: + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), + "lmformatenforcer") + lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), + "lmformatenforcer.integrations.vllm") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase, + vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( + tokenizer, use_bitmask=True, vocab_size=vocab_size) + + +@dataclass +class LMFormatEnforcerGrammar(StructuredOutputGrammar): + token_enforcer: lmformatenforcer.TokenEnforcer + current_tokens_prefix: list[int] = field(default_factory=list) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + original_len = len(self.current_tokens_prefix) + for token in tokens: + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix).is_token_allowed(token): + # Rollback partial updates to ensure atomicity. + del self.current_tokens_prefix[original_len:] + return False + self.current_tokens_prefix.append(token) + return True + + def validate_tokens(self, tokens: list[int]) -> list[int]: + for prefix_length in range(len(tokens)): + prefix = tokens[:prefix_length] + next_token = tokens[prefix_length] + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + + prefix).is_token_allowed(next_token): + break + else: + return tokens + + return tokens[:prefix_length] + + def rollback(self, num_tokens: int) -> None: + self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens] + + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + allowed_tokens = self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix) + bitmask[batch_index] = allowed_tokens.allowed_tokens + + def is_terminated(self) -> bool: + # We are considered terminated if the prefix ends with eos_token_id + return_value = len( + self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ + -1] == self.token_enforcer.eos_token_id + return return_value + + def reset(self): + self.current_tokens_prefix = [] + + +@dataclass +class LMFormatEnforcerBackend(StructuredOutputBackend): + + def __post_init__(self): + self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + self.tokenizer, self.vocab_size) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + character_level_parser: lmformatenforcer.CharacterLevelParser + if request_type == StructuredOutputOptions.JSON: + spec_dict = json.loads(grammar_spec) + character_level_parser = lmformatenforcer.JsonSchemaParser( + spec_dict) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + character_level_parser = lmformatenforcer.JsonSchemaParser(None) + elif request_type == StructuredOutputOptions.REGEX: + character_level_parser = lmformatenforcer.RegexParser(grammar_spec) + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + character_level_parser = lmformatenforcer.UnionParser( + [lmformatenforcer.StringParser(choice) for choice in choices]) + else: + raise ValueError( + "Invalid request type for LM Format Enforcer backend" + f"({request_type!s})") + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None else 0) + + if max_rollback_tokens > 0: + raise ValueError( + "LM Format Enforcer backend does not support speculative tokens" + ) + + token_enforcer = lmformatenforcer.TokenEnforcer( + tokenizer_data=self.tokenizer_data, + parser=character_level_parser, + ) + return LMFormatEnforcerGrammar(token_enforcer) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +def validate_structured_output_request_lm_format_enforcer( + params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + return + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + # make sure schema is valid json + json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + json.dumps(gd_params.json) + except Exception as e: + raise ValueError( + f"Error serializing guided decoding jsonschema: {e}" + ) from e + return + elif gd_params.choice: + return + elif gd_params.grammar: + raise ValueError("LM Format Enforcer guided decoding backend " + "does not support grammar specifications") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index d500783aa4b30c6fd7c969178cc6c246d3646796..9a53aa7a1ad1088bb635d8f559126e2b96f1efe9 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -110,7 +110,7 @@ class StructuredOutputBackend(ABC): Args: request_type (StructuredOutputOptions): The type of structured - output request. + output request. grammar_spec (str): The grammar specification to compile. Returns: @@ -124,7 +124,7 @@ class StructuredOutputBackend(ABC): Args: max_num_seqs (int): The maximum number of sequences for which - to allocate the bitmask. + to allocate the bitmask. """ @abstractmethod diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b5750c82db023e5ad0efee33007abfe45b85cca6..8f9face6fbf2ee7a3b3d416ce96b7345c12e86d8 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -96,6 +96,35 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" +class CpuGpuBuffer: + + def __init__( + self, + *args, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + ): + self.cpu = torch.zeros(*args, + dtype=dtype, + device="cpu", + pin_memory=pin_memory) + self.np = self.cpu.numpy() + self.gpu = self.cpu.to(device) + + def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + if n is None: + return self.gpu.copy_(self.cpu, non_blocking=True) + return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + + def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + """NOTE: Because this method is non-blocking, explicit synchronization + is needed to ensure the data is copied to CPU.""" + if n is None: + return self.cpu.copy_(self.gpu, non_blocking=True) + return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) + + def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5662fc350e198761e79a3f95228e677649d47d3c..6ab5ce2748a4a05b6b74654a5a2147e6eaf16a85 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -110,7 +110,7 @@ class BlockTable: self.block_table_cpu.fill_(0) def get_device_tensor(self) -> torch.Tensor: - """Ruturns the device tensor of the block table.""" + """Returns the device tensor of the block table.""" return self.block_table def get_cpu_tensor(self) -> torch.Tensor: diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index a7180afbd64b52486a202f24338469c72fd6c3b1..feb49978d7518a1d82eb6f087298b5c2d3c55e4f 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch import torch.nn as nn @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: @@ -21,7 +22,8 @@ logger = init_logger(__name__) class CPUModelRunner(GPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) assert device == torch.device("cpu") assert self.speculative_config is None, "spec decode is not supported." @@ -41,7 +43,7 @@ class CPUModelRunner(GPUModelRunner): Args: scheduler_output: The scheduler output. """ - # Attention free models have zero kv_cache_goups, however models + # Attention free models have zero kv_cache_groups, however models # like Mamba are also attention free but use the kv_cache for # keeping its internal state. This is why we check the number # of kv_cache groups instead of solely checking @@ -71,8 +73,8 @@ class CPUModelRunner(GPUModelRunner): setattr(obj, device_attr_name, cpu_tensor) for k, v in vars(self).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(self, k, k[:-4]) + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu for k, v in vars(self.input_batch).items(): if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor): @@ -108,17 +110,42 @@ class CPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: pass + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + return sampled_token_ids.tolist() + + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + # Note: For CPU backend, dp padding is not required for now. + return 0, None + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + cuda_event = torch.cuda.Event + try: + torch.cuda.Event = _EventPlaceholder + yield + finally: + torch.cuda.Event = cuda_event + @contextmanager def _set_global_compilation_settings(config: VllmConfig): - import torch._inductor.config + import torch._inductor.config as torch_inductor_config inductor_config = config.compilation_config.inductor_compile_config + # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. + freezing_value = torch_inductor_config.freezing try: - # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. - freezing_value = torch._inductor.config.freezing if inductor_config.get("max_autotune", False): - torch._inductor.config.freezing = True + torch_inductor_config.freezing = True yield finally: - torch._inductor.config.freezing = freezing_value + torch_inductor_config.freezing = freezing_value diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 2dc28d93049ab508bdc74186531a15082cc047c9..b87c4fe09bb90d10a5d22ca87d21da1aa80dbdc7 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -43,8 +43,9 @@ class CPUWorker(Worker): # Setup OpenMP threads affinity. omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND if omp_cpuids == "auto" and platform.system() == "Linux": - if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC: - # For POWERPC SMT-8/4/2 + cpu_arch = current_platform.get_cpu_architecture() + if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): + # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -54,7 +55,14 @@ class CPUWorker(Worker): else: self.local_omp_cpuid = "all" else: - self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] + local_dp_rank = self.parallel_config.data_parallel_rank_local + omp_cpuids = omp_cpuids.split("|") + if local_dp_rank is not None: + world_size = self.parallel_config.world_size + omp_cpuids = omp_cpuids[local_dp_rank * + world_size:(local_dp_rank + 1) * + world_size] + self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) @@ -132,7 +140,7 @@ class CPUWorker(Worker): """ allowed_numa_nodes, logical_cpu_list = \ - CpuPlatform.get_allowed_cpu_memory_node_list() + CpuPlatform.get_allowed_cpu_core_node_list() assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " @@ -161,7 +169,9 @@ class CPUWorker(Worker): # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0 + need_reserve = (self.parallel_config.world_size > 1 or + self.parallel_config.data_parallel_size_local > 1) + reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " f"should less than {len(logical_cpu_list)}.") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b295bf08da2d67bf80a932fd8e9294f2dcdf4e2f..3072f7a2f324d6ea72e1607fc7233e77c574b8e2 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,8 +10,8 @@ import torch from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalKwargsItem, + MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values @@ -33,6 +33,7 @@ class CachedRequestState: prompt_token_ids: list[int] mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] + mm_hashes: list[str] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -58,14 +59,15 @@ class CachedRequestState: @property @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " "removed in v0.13. Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargs]: - return [MultiModalKwargs([item]) for item in self.mm_kwargs] + def mm_inputs(self) -> list[MultiModalKwargsItems]: + return [ + MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs + ] def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: - return self.output_token_ids[idx - self.num_prompt_tokens] + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: @@ -260,30 +262,27 @@ class InputBatch: Not applicable to pooling models. """ - # Detailed added request metadata is only required for non-pooling - # models, to support logitsprocs - assert request.sampling_params - # Fill the next empty index if there is one. if (new_req_index := self.batch_update_builder.pop_removed()) is None: # Append to end otherwise. new_req_index = self.num_reqs assert new_req_index < self.max_num_reqs - self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, request.prompt_token_ids, - request.output_token_ids)) + self.batch_update_builder.batch_changed = True + if request.sampling_params: + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs. + self.batch_update_builder.added.append( + (new_req_index, request.sampling_params, + request.prompt_token_ids, request.output_token_ids)) + return new_req_index def add_request( self, request: "CachedRequestState", ) -> int: - if not self.is_pooling_model: - # New request index bookkeeping for autoregressive models. - req_index = self._register_add_request(request) - else: - req_index = self.num_reqs + req_index = self._register_add_request(request) req_id = request.req_id if req_index == len(self._req_ids): @@ -395,7 +394,7 @@ class InputBatch: self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids) else: - raise NotImplementedError(request) + raise NotImplementedError("Unrecognized request type") # Add request lora ID if request.lora_request: @@ -425,13 +424,25 @@ class InputBatch: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - if not self.is_pooling_model: - # Autoregressive models require bookkeeping of removed requests to - # support logitsprocs. - self.batch_update_builder.removed_append(req_index) + + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + lora_req_ids = self.lora_id_to_request_ids[lora_id] + lora_req_ids.discard(req_id) + if not lora_req_ids: + del self.lora_id_to_request_ids[lora_id] + del self.lora_id_to_lora_request[lora_id] + self.request_lora_mapping[req_index] = 0 + + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index + self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) @@ -445,28 +456,14 @@ class InputBatch: self.num_prompt_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - self.lora_id_to_request_ids[lora_id].discard(req_id) - if len(self.lora_id_to_request_ids[lora_id]) == 0: - self.lora_id_to_request_ids.pop(lora_id) - self.lora_id_to_lora_request.pop(lora_id) - self.request_lora_mapping[req_index] = 0 - self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: - # For autoregressive models, track detailed request reordering info - # to support logitsprocs - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] self._req_ids[i1], self._req_ids[i2] =\ @@ -484,18 +481,6 @@ class InputBatch: self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -506,18 +491,41 @@ class InputBatch: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp - swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.bad_words_token_ids, i1, i2) + self.block_table.swap_row(i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ self.request_lora_mapping[i2], self.request_lora_mapping[i1] + if self.is_pooling_model: + # Sampling and logits parameters don't apply to pooling models. + return + + # For autoregressive models, track detailed request reordering info + # to support logitsprocs. + self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) + + self.temperature_cpu[i1], self.temperature_cpu[i2] = \ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] = \ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = \ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) + if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[i1], \ self.allowed_token_ids_mask_cpu_tensor[i2] =\ self.allowed_token_ids_mask_cpu_tensor[i2], \ self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -525,21 +533,12 @@ class InputBatch: Any consecutive empty indices at the very end of the list are not filled. - Args: - empty_req_indices: empty indices which may be filled. - Returns: swaps: list of (from,to) swap tuples for moved requests empty_req_indices: indices not filled by condensation """ num_reqs = self.num_reqs - if self.is_pooling_model: - # Will be contiguous in pooling case, just trim the lists. - del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - return - if not (empty_req_indices := self.batch_update_builder.removed): # All removed requests were replaced by added requests, or else no # requests were removed at all. No condense() needed @@ -567,11 +566,6 @@ class InputBatch: # Move active request down into empty request # index. self.batch_update_builder.pop_removed() - # Autoregressive models require detailed tracking of condense - # operations to support logitsprocs - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) req_id = self._req_ids[last_req_index] output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None @@ -592,6 +586,21 @@ class InputBatch: self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table.move_row(last_req_index, empty_index) + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + if self.is_pooling_model: + last_req_index -= 1 + # Samping state not used by pooling models. + continue + + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + self.temperature_cpu[empty_index] = self.temperature_cpu[ last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] @@ -606,9 +615,6 @@ class InputBatch: if generator is not None: self.generators[empty_index] = generator - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[ @@ -631,8 +637,9 @@ class InputBatch: """Apply any batch updates to sampling metadata.""" if self.is_pooling_model: - # Batch changes every step for pooling models. - self.sampling_metadata = self._make_sampling_metadata() + batch_changed = self.batch_update_builder.reset() + if batch_changed: + self.sampling_metadata = self._make_sampling_metadata() return # For non-pooling models - generate and apply logitsprocs update; @@ -719,13 +726,14 @@ class InputBatch: return PoolingMetadata( prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]).to(self.device), + self.num_prompt_tokens[:self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + num_reqs = self.num_reqs + max_prompt_len = self.num_prompt_tokens[:num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -733,11 +741,10 @@ class InputBatch: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. - for i in range(self.num_reqs): + for i in range(num_reqs): prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 595205b8b7ccb0bb3906ee5becc7058ae953d026..fa47c07bdd27afc2e2cc199d678fe4bafb97279d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import gc import itertools import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -34,7 +34,8 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import (BatchDescriptor, DPMetadata, set_forward_context, set_profilling) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (is_mixture_of_experts, @@ -54,19 +55,19 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - make_kv_sharing_fast_prefill_attention_metadata, + create_fast_prefill_custom_backend, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -76,6 +77,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) @@ -83,9 +85,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.platforms import current_platform from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model -from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, - gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr @@ -138,9 +141,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = model_config.pooler_config is not None - self.is_encoder_only_model = False - self.is_multimodal_raw_input_supported = ( - model_config.is_multimodal_raw_input_supported) + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model) + self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -150,6 +153,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn @@ -177,8 +182,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -244,24 +249,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._init_device_properties() # Persistent buffers for CUDA graphs. - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.query_start_loc = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device=self.device) - self.seq_lens = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - - # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.input_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -275,23 +273,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # identical position IDs, making M-RoPE functionally equivalent to # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), - dtype=torch.int64, - device=self.device) - self.mrope_positions_cpu = torch.zeros( - (3, self.max_num_tokens + 1), - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.mrope_positions_np = self.mrope_positions_cpu.numpy() - - # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int64) - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context @@ -299,28 +285,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_model_len, self.max_num_tokens), dtype=np.int64) - # NOTE(woosuk): These tensors are "stateless", i.e., they are literally - # a faster version of creating a new tensor every time. Thus, we should - # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_lens_np = self.seq_lens_cpu.numpy() # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -344,13 +308,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, - ) if self.supports_mm_inputs \ - else None) + ) if self.supports_mm_inputs else None) self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + + # Cached outputs. + self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + self.transfer_event = torch.cuda.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_model_len, 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(*args, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory) + def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() num_reqs = self.input_batch.num_reqs @@ -360,6 +342,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if num_pooling_reqs == 0: return model_kwargs + # This does nontrivial work. pooling_params = self.input_batch.pooling_metadata.pooling_params assert num_pooling_reqs == num_reqs @@ -374,7 +357,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens[:num_reqs] + seq_lens = self.seq_lens.gpu[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -434,8 +417,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -446,12 +427,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.remove_request(req_id) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -468,7 +445,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) - req_ids_to_add: list[str] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -483,18 +460,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): generator = None if pooling_params: - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" model = cast(VllmModelForPooling, self.get_model()) to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) - self.requests[req_id] = CachedRequestState( + req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -503,46 +481,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[], lora_request=new_req_data.lora_request, ) + self.requests[req_id] = req_state # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_item in self.requests[req_id].mm_kwargs: - mm_input = mm_item.get_data() - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.append( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.append( - mm_input["video_grid_thw"].tolist()) - if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.append( - mm_input["second_per_grid_ts"]) - if mm_input.get("audio_feature_lengths") is not None: - audio_feature_lengths.append( - mm_input["audio_feature_lengths"]) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - hf_config = self.model_config.hf_config - - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + self._init_mrope_positions(req_state) - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank @@ -578,11 +523,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the block IDs. if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -592,13 +539,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) continue # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -625,9 +574,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state) + for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -636,42 +584,67 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", ) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: # noqa: SIM102 - if scheduler_output: - mm_kwargs = list[MultiModalKwargsItem]() - for req in scheduler_output.scheduled_new_reqs: - req_mm_kwargs = req.mm_kwargs - if not isinstance(req_mm_kwargs, list): - req_mm_kwargs = list(req_mm_kwargs) - mm_kwargs.extend(req_mm_kwargs) - - # Input all modalities at once - mm_kwargs_combined: BatchedTensorInputs = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - mm_kwargs_combined.update(mm_kwargs_group) + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} + + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + mm_kwargs.extend(req.mm_kwargs) - return mm_kwargs_combined + # Input all modalities at once + mm_kwargs_combined: BatchedTensorInputs = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + mm_kwargs_combined.update(mm_kwargs_group) - return {} + return mm_kwargs_combined def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: - mm_budget = self.mm_budget - assert mm_budget is not None - - dummy_modality, _ = mm_budget.get_modality_with_max_tokens() + if not self.is_multimodal_raw_input_only_model: + return {} - return self._get_mm_dummy_batch(dummy_modality, num_seqs) + mm_budget = self.mm_budget + assert mm_budget is not None - return {} + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) def _get_cumsum_and_arange( self, @@ -730,7 +703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens) # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] + positions_np = self.positions.np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) @@ -753,7 +726,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + out=self.input_ids.cpu[:total_num_scheduled_tokens]) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -761,42 +734,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): total_num_scheduled_tokens) # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] - self.seq_lens_np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = self.seq_lens.np[:num_reqs].max().item() # Copy the tensors to the GPU. - self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True) else: # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) - - # Fill unused with 0 for full cuda graph mode. - self.seq_lens[num_reqs:].fill_(0) - # Note: pad query_start_loc to be non-decreasing, as kernels - # like FlashAttention requires that - self.query_start_loc[num_reqs + 1:].fill_( - self.query_start_loc_cpu[num_reqs].item()) - - query_start_loc = self.query_start_loc[:num_reqs + 1] - - spec_decode_common_attn_metadata = None + self.positions.copy_to_gpu(total_num_scheduled_tokens) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -824,73 +788,65 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: - assert self.kv_sharing_fast_prefill_logits_indices is not None - num_logits = logits_indices.shape[0] - assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices_padded = self._prepare_kv_sharing_fast_prefill( logits_indices) - # There might have leftover indices in logits_indices[num_logits:] - # from previous iterations, whose values may be greater than the - # batch size in the current iteration. To ensure indices are always - # valid, we fill the padded indices with the last index. - self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph( - num_logits) - else: - num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded] - ) attn_metadata: dict[str, Any] = {} - # Prepare encoder attention metadata separately - # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - - per_layer_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) - - # Add encoder attention metadata for all encoder layers - attention_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - common_attn_metadata, encoder_attn_metadata =\ - per_layer_metadata[layer_name] - attn_metadata[layer_name] = encoder_attn_metadata + # Used in the below loop. + query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] - - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens, ), + dtype=torch.int64, + device=self.device, + ) + num_common_prefix_blocks = 0 + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[: + total_num_scheduled_tokens] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = ( + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id]) common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), causal=True, ) @@ -905,8 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id], + num_common_prefix_blocks, kv_cache_group_spec.kv_cache_spec, builder, ) @@ -916,28 +871,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): common_attn_metadata=common_attn_metadata, )) - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) - for layer_name in attn_group.layer_names: - if (self.cache_config.kv_sharing_fast_prefill - and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue attn_metadata[layer_name] = attn_metadata_i # Hot-Swap lora model @@ -1073,9 +1007,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - req.mrope_positions[:,src_start:src_end] - + self.mrope_positions.cpu[:, dst_start:dst_end] = ( + req.mrope_positions[:, src_start:src_end]) mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1084,7 +1017,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dst_end = mrope_pos_ptr + completion_part_len MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions_np, + out=self.mrope_positions.np, out_offset=dst_start, mrope_position_delta=req.mrope_position_delta, context_len=num_computed_tokens + prompt_part_len, @@ -1159,7 +1092,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] - draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] metadata = SpecDecodeMetadata( @@ -1172,21 +1105,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return metadata + def _prepare_kv_sharing_fast_prefill( + self, + logits_indices: torch.Tensor, + ) -> torch.Tensor: + assert self.kv_sharing_fast_prefill_logits_indices is not None + num_logits = logits_indices.shape[0] + assert num_logits > 0 + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices) + # There might have leftover indices in logits_indices[num_logits:] + # from previous iterations, whose values may be greater than the + # batch size in the current iteration. To ensure indices are always + # valid, we fill the padded indices with the last index. + self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( + logits_indices[-1].item()) + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = ( + self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + return logits_indices_padded + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return - # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1219,15 +1179,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for output in curr_group_outputs: encoder_outputs.append(output) - # Cache the encoder outputs. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) @@ -1245,6 +1199,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1264,11 +1219,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] + + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] @@ -1312,10 +1270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): and "encode" in supported_tasks): supported_tasks.remove("encode") - logger.info_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once("Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it.") + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once( + "Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1447,7 +1412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model, is_dummy, is_profile, - log_stats=self.parallel_config.eplb_log_balancedness, + log_stats=self.parallel_config.eplb_config.log_balancedness, ) def get_dp_padding(self, @@ -1487,31 +1452,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "Either all or none of the requests in" \ " a batch must be pooling request" - extracted_hidden_states = list( - torch.split(hidden_states[:num_scheduled_tokens], - num_scheduled_tokens_np.tolist())) - + hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.pooling_metadata + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + # Pooling models D2H & synchronize occurs in pooler.py:build_output raw_pooler_output = self.model.pooler( - hidden_states=extracted_hidden_states, - pooling_metadata=pooling_metadata) + hidden_states=hidden_states, pooling_metadata=pooling_metadata) pooler_output: list[Optional[torch.Tensor]] = [] - seq_lens = self.seq_lens[:self.input_batch.num_reqs] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens, pooling_metadata.prompt_lens): + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - if seq_len == prompt_len: - pooler_output.append(raw_output.data.cpu()) - else: - pooler_output.append(None) + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, @@ -1533,13 +1494,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, tokens, please disable it when the requests " + "need prompt logprobs") + # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = (self._prepare_inputs(scheduler_output)) + max_query_len) = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use CUDA graphs. # Add padding to the batch size. @@ -1574,7 +1542,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_scheduled_tokens], + input_ids=self.input_ids.gpu[:num_scheduled_tokens], multimodal_embeddings=mm_embeds or None, ) @@ -1593,13 +1561,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] + positions = self.mrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions[:num_input_tokens] + positions = self.positions.gpu[:num_input_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1632,6 +1600,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): batch_descriptor=batch_descriptor, ), self.maybe_get_kv_connector_output( scheduler_output) as kv_connector_output: + model_output = self.model( input_ids=input_ids, positions=positions, @@ -1746,7 +1715,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:num_scheduled_tokens], - scheduler_output, + scheduler_output.num_scheduled_tokens, ) # Get the valid generated tokens. @@ -1755,7 +1724,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if max_gen_len == 1: # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + valid_sampled_token_ids = self._to_list(sampled_token_ids) else: # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( @@ -1771,6 +1740,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. + req_ids = self.input_batch.req_ids for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): if not sampled_ids: continue @@ -1786,16 +1756,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if not self.speculative_config: - # Speculative decoding is not enabled. - spec_token_ids = None - else: + if self.speculative_config: assert spec_decode_common_attn_metadata is not None - spec_token_ids = self.propose_draft_token_ids( + self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, sampling_metadata, @@ -1812,7 +1779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1820,6 +1786,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -1830,11 +1807,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]]: + ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.propose_ngram_draft_token_ids( + draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) @@ -1852,13 +1829,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] - spec_token_ids = self.drafter.propose( + draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. + req_ids = self.input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): if token_ids: @@ -1867,7 +1845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: # Partial prefill (rare case). # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1879,9 +1857,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[:num_scheduled_tokens] + target_positions = self.positions.gpu[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1901,22 +1879,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) - # num_accepted_tokens = [len(s) - 1 for s in sampled_token_ids] - # num_accepted_tokens_tensor = async_tensor_h2d( - # num_accepted_tokens, - # dtype=torch.int32, - # target_device=self.device, - # pin_memory=True) - - # num_accepted_tokens_cpu = torch.tensor(num_accepted_tokens, - # dtype=torch.int32) - # common_attn_metadata, token_indices =\ - # self.drafter.prepare_inputs( - # common_attn_metadata, num_accepted_tokens_cpu) - - target_token_ids = self.input_ids[token_indices] + target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[token_indices] + target_positions = self.positions.gpu[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1937,15 +1902,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds=mm_embeds, decoding=spec_decode_metadata is not None, ) - spec_token_ids = draft_token_ids.tolist() - - return spec_token_ids + return draft_token_ids def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], ) -> list[list[int]]: # TODO(woosuk): Optimize. + req_ids = self.input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) @@ -1956,7 +1920,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Skip requests that require sampling parameters that are not # supported with speculative decoding. - req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] if req_id in self.input_batch.spec_decode_unsupported_reqs: draft_token_ids.append([]) continue @@ -2004,7 +1968,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] - self.parallel_config.num_redundant_experts = ( + self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 @@ -2104,7 +2068,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, - scheduler_output: "SchedulerOutput", + num_scheduled_tokens: dict[str, int], ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: @@ -2117,8 +2081,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - - num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens[req_id] # Get metadata for this request. request = self.requests[req_id] @@ -2161,7 +2124,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If this is a partial request (i.e. chunked prefill), # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] - offset = self.query_start_loc_np[req_idx].item() + offset = self.query_start_loc.np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states, None) @@ -2234,7 +2197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @functools.cache def rand_input_ids() -> torch.Tensor: return torch.randint_like( - self.input_ids, + self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), dtype=input_ids.dtype) @@ -2251,19 +2214,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch return next(mm_kwargs_group for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - [dummy_mm_item] * max_items_per_batch, + dummy_mm_items, device=self.device, pin_memory=self.pin_memory, )) @@ -2277,6 +2244,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): uniform_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, + remove_lora: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -2289,11 +2257,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.FULL: Full cudagraph, attention metadata is needed. - force_attention: If True, always create attention metadata. Used to + force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. + remove_lora: If False, dummy LoRAs are not destroyed after the run """ assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL @@ -2350,31 +2319,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. - if force_attention or cudagraph_runtime_mode == \ - CUDAGraphMode.FULL: + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} # Make sure max_model_len is used at the graph capture time. - self.seq_lens_np[:num_reqs] = self.max_model_len - self.seq_lens_np[num_reqs:] = 0 - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) + self.seq_lens.np[:num_reqs] = self.max_model_len + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() num_speculative_tokens = 0 if self.speculative_config is None else self.speculative_config.num_lookahead_slots for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + seq_lens=self.seq_lens.gpu[:num_reqs], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], num_computed_tokens_cpu=self.input_batch. num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=self.max_model_len, num_speculative_tokens=num_speculative_tokens, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], @@ -2389,7 +2357,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens): + num_scheduled_tokens, remove_lora): if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2398,14 +2366,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): **self._dummy_mm_kwargs(num_reqs), } else: - input_ids = self.input_ids[:num_tokens] + input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens] else: - positions = self.positions[:num_tokens] + positions = self.positions.gpu[:num_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -2560,13 +2528,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - hidden_states_list = list( - torch.split(hidden_states, num_scheduled_tokens_list)) req_num_tokens = num_tokens // num_reqs dummy_prompt_lens = torch.tensor( - [h.shape[0] for h in hidden_states_list], - device=self.device, + num_scheduled_tokens_list, + device="cpu", ) dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), dtype=torch.int32, @@ -2583,8 +2549,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooling_params=[dummy_pooling_params] * num_reqs, ) + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + try: - return model.pooler(hidden_states=hidden_states_list, + return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): @@ -2632,14 +2601,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " @@ -2791,11 +2755,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, uniform_decode=uniform_decode, - skip_eplb=True) + skip_eplb=True, + remove_lora=False) self._dummy_run(num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, uniform_decode=uniform_decode, - skip_eplb=True) + skip_eplb=True, + remove_lora=False) + self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2803,20 +2770,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ assert len(self.attn_groups) == 0, \ "Attention backends are already initialized" - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) def get_attn_backends_for_layers( layer_names: list[str] ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) - # Dedupe based on full class name; this is a bit safer than using + # Dedupe based on full class name; this is a bit safer than # using the class itself as the key because when we create dynamic # attention backend subclasses (e.g. ChunkedLocalAttention) unless # they are cached correctly, there will be different objects per # layer. for layer_name in layer_names: - attn_backend = attn_layers[layer_name].get_attn_backend() + attn_backend = layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -2845,69 +2821,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backends = get_attn_backends_for_layers( - kv_cache_group_spec.layer_names) - # TODO(lucas): move `get_mamba_attn_backend` into the mamba - # layers like above - elif isinstance(kv_cache_spec, MambaSpec): - attn_backends = { - get_mamba_attn_backend(kv_cache_spec.mamba_type): - kv_cache_group_spec.layer_names - } - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) self.attn_groups.append( create_attn_groups(attn_backends, kv_cache_spec)) # Calculate reorder batch threshold (if neeeded) self.calculate_reorder_batch_threshold() - if len(self.attn_groups) > 0: - return - - # Check if model is encoder-only - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) - for layer_name, attn_module in attn_layers.items(): - - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - if attn_module.sliding_window is None: - attn_spec: AttentionSpec = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - else: - attn_spec = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) - attn_specs[attn_spec].append(layer_name) - - else: - raise ValueError("Expected only encoder-only layers") - - if len(attn_specs) > 0: - total_layers = 0 - for attn_spec, layer_names in attn_specs.items(): - - attn_backends = get_attn_backends_for_layers(layer_names) - total_layers += len(layer_names) - - self.attn_groups.append( - create_attn_groups(attn_backends, attn_spec)) - assert total_layers == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" - self.is_encoder_only_model = True - def initialize_cudagraph_capture(self) -> None: min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None @@ -3055,7 +2976,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): layer_names = set() for group in kv_cache_config.kv_cache_groups: - layer_names.update(group.layer_names) + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -3083,7 +3007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: kv_cache_config: The KV cache config kv_cache_raw_tensors: The KV cache buffer of each layer, with - correct size but uninitialized shape. + correct size but uninitialized shape. Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. @@ -3093,6 +3017,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // @@ -3205,40 +3131,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): raise NotImplementedError if has_attn and has_mamba: - self._verify_hybrid_attention_mamba_layout(kv_cache_config, - kv_cache_raw_tensors) + self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches - def _verify_hybrid_attention_mamba_layout( - self, kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor]) -> None: """ - Verify that the KV cache memory layout is compatible for - models with both attention and mamba KV cache groups. + Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...). Args: - kv_cache_config: The KV cache config - kv_cache_raw_tensors: The KV cache buffer of each layer. + kv_caches: The KV cache buffer of each layer. """ for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): for layer_name in group.layer_names: - raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, AttentionSpec): - - kv_cache_shape = group.backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - if kv_cache_shape[0] != num_blocks or kv_cache_shape[ - 1] != 2: - raise ValueError( - "Hybrid models in V1 require an attention " - "backend with kv_cache_shape=" - "(num_blocks, 2, ...). Please try setting " - "VLLM_ATTENTION_BACKEND=FLASHINFER") + kv_cache = kv_caches[layer_name] + if (isinstance(kv_cache_spec, AttentionSpec) + and kv_cache.shape[0] == 2): + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + f"a tensor of shape {kv_cache.shape}" + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: @@ -3257,19 +3176,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - self.attn_groups, - ) + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: self.kv_sharing_fast_prefill_eligible_layers.add( @@ -3277,11 +3217,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: break - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - return kv_caches - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -3289,8 +3224,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -3304,6 +3242,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. + """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3393,64 +3358,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return kv_cache_spec - def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - dict[str, tuple[CommonAttentionMetadata, Any]]: - """Prepare encoder attention metadata for encoder-only models. - - Args: - scheduler_output: Scheduler output - - Returns: - dict[str, Any]: Encoder attention metadata - """ - num_reqs = self.input_batch.num_reqs - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - max_num_scheduled_tokens = max(tokens) - - dummy_block_table = torch.zeros((num_reqs, 1), - dtype=torch.int32, - device=self.device) - dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - device=self.device) - - group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]() - - for attn_group_list in self.attn_groups: - - assert len(attn_group_list) == 1 - attn_group = attn_group_list[0] - - # Use the first attention metadata builder - # to create encoder attention metadata - builder = attn_group.metadata_builder - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - block_table_tensor=dummy_block_table, - slot_mapping=dummy_slot_mapping, - causal=False, - ) - - metadata = builder.build( - common_prefix_len=0, # No cascade for encoder - common_attn_metadata=common_metadata, - ) - - for layer_name in attn_group.layer_names: - group_metadata[layer_name] = (common_metadata, metadata) - - return group_metadata + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. + pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a1174fa0663c533fdbe399f2e4eac7b8775007fe..245703f6dfc63d103ba90c9dd78e81fd7c9516f1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -168,7 +169,7 @@ class Worker(WorkerBase): self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) - _check_if_gpu_supports_dtype(self.model_config.dtype) + current_platform.check_if_supports_dtype(self.model_config.dtype) gc.collect() torch.cuda.empty_cache() @@ -222,8 +223,7 @@ class Worker(WorkerBase): self.model_runner.update_config(overrides) def reload_weights(self) -> None: - with self._maybe_get_memory_pool_context(tag="weights"): - self.model_runner.reload_weights() + self.model_runner.reload_weights() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -231,7 +231,7 @@ class Worker(WorkerBase): memory can be used for KV cache without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the free memory that can be used for KV cache in + Then, it calculates the free memory that can be used for KV cache in bytes. Tip: @@ -298,7 +298,6 @@ class Worker(WorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.initialize_kv_cache(kv_cache_config) @@ -316,7 +315,14 @@ class Worker(WorkerBase): # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, skip_eplb=True) + self.model_runner._dummy_run(size, + skip_eplb=True, + remove_lora=False) + self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) + + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -342,9 +348,6 @@ class Worker(WorkerBase): self.model_runner._dummy_sampler_run( hidden_states=last_hidden_states) - # Warmup kernels used during model execution - kernel_warmup(self) - # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) @@ -361,7 +364,8 @@ class Worker(WorkerBase): scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: intermediate_tensors = None - if not get_pp_group().is_first_rank: + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) @@ -375,30 +379,34 @@ class Worker(WorkerBase): output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - parallel_config = self.vllm_config.parallel_config - if parallel_config.distributed_executor_backend != "external_launcher" \ - and not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output + if isinstance(output, ModelRunnerOutput): return output - assert isinstance(output, ModelRunnerOutput) + assert isinstance(output, IntermediateTensors) + parallel_config = self.vllm_config.parallel_config + assert parallel_config.distributed_executor_backend != ( + "external_launcher") and not get_pp_group().is_last_rank + + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + + kv_connector_output = output.kv_connector_output + if not kv_connector_output: + return None + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if (not kv_connector_output.finished_sending + and not kv_connector_output.finished_recving): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + return self.model_runner.take_draft_token_ids() + def profile(self, is_start: bool = True): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -524,7 +532,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) global_expert_load = None @@ -540,7 +548,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_load.shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( @@ -623,23 +631,3 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 2fbdee4724e35fc1e6bf88fb8ce87c4064560f78..4b5f27d27541b9cca09cbb1a8b9bc94cdc21bb8b 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -5,9 +5,10 @@ Define LoRA functionality mixin for model runners. """ from contextlib import contextmanager -from typing import Union +from typing import Optional, Union import numpy as np +import torch import torch.nn as nn from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig @@ -31,7 +32,8 @@ class LoRAModelRunnerMixin: def load_lora_model(self, model: nn.Module, model_config: ModelConfig, scheduler_config: SchedulerConfig, - lora_config: LoRAConfig, device: str) -> nn.Module: + lora_config: LoRAConfig, + device: torch.device) -> nn.Module: if not supports_lora(model): raise ValueError( @@ -85,7 +87,9 @@ class LoRAModelRunnerMixin: lora_requests) @contextmanager - def maybe_setup_dummy_loras(self, lora_config): + def maybe_setup_dummy_loras(self, + lora_config: Optional[LoRAConfig], + remove_lora: bool = True): if lora_config is None: yield else: @@ -112,10 +116,11 @@ class LoRAModelRunnerMixin: yield # __exit__ code - self.lora_manager.remove_all_adapters() + if remove_lora: + self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: LoRAConfig, + def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray): if lora_config is None: yield @@ -149,13 +154,22 @@ class LoRAModelRunnerMixin: yield @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): - with self.maybe_setup_dummy_loras( - lora_config), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens): + def maybe_dummy_run_with_lora(self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + remove_lora: bool = True): + with ( + self.maybe_setup_dummy_loras(lora_config, remove_lora), + self.maybe_select_dummy_loras(lora_config, + num_scheduled_tokens), + ): yield + def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]): + if lora_config is None: + return + self.lora_manager.remove_all_adapters() + def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f7e68edba3a13cdd254b3319e68d9036751ae17f..985d5ba58c49c6a8bdf3d88b62c17d4235ec324b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, bind_kv_cache, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs) +from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -208,8 +207,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -292,8 +291,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, ) if self.supports_mm_inputs else None) if not self.use_spmd: @@ -344,7 +341,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and @@ -359,12 +355,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): removed_req_indices.append(req_index) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -396,6 +388,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -418,11 +411,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -438,7 +433,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -554,7 +551,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return kv_cache_spec def _get_slot_mapping_metadata(self, num_reqs, - num_scheduled_tokens_per_req): + num_scheduled_tokens_per_req) -> np.ndarray: """ Computes metadata for mapping slots to blocks in the key-value (KV) cache for a batch of requests. @@ -567,15 +564,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_reqs (int): Number of requests in the current batch. num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens - to be scheduled for each request. + to be scheduled for each request. Returns: np.ndarray: A 2D array of shape (total_block_len, 3), where each row - contains: + contains: - kv_cache_start_index (int): The starting index in the KV cache - for the corresponding slice. + for the corresponding slice. - new_kv_start_index (int): The starting index in the new KV - cache for the corresponding slice. + cache for the corresponding slice. - slice_len (int): The length of the slice. """ slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] @@ -811,31 +808,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ num_reqs, end_index - def _scatter_placeholders( - self, - embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return embeds - - placeholders = embeds.new_full( - (is_embed.shape[0], embeds.shape[-1]), - fill_value=torch.nan, - ) - placeholders[is_embed] = embeds - return placeholders - - def _gather_placeholders( - self, - placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return placeholders - - return placeholders[is_embed] - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -843,14 +815,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # List of tuple (mm_hash, pos_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -893,15 +867,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE (NickLucche) here we diverge from logic in other runners, as we # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): assert pos_info.is_embed is None, "Expected all positions to be"\ " contiguous and embeddings." - self.encoder_cache[req_id][input_id] = output + self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( self, @@ -914,6 +883,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid @@ -934,11 +904,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the decoder's KV cache. continue - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." - encoder_output = self.encoder_cache[req_id][i] + encoder_output = self.encoder_cache[mm_hash] mm_embeds.append(encoder_output) return mm_embeds @@ -1145,7 +1117,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1542,14 +1513,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " @@ -1602,6 +1568,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.clear() gc.collect() + def maybe_setup_cross_layer_kv_sharing( + self, + kv_caches: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1667,14 +1657,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: raise NotImplementedError - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing if needed + self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) bind_kv_cache( kv_caches, @@ -1816,19 +1800,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch return next(grouped_mm_kwargs for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - [dummy_mm_item] * max_items_per_batch, + dummy_mm_items, device=self.device, pin_memory=self.pin_memory, )) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 72e0e4230a0179eae8e742fb8f1d3a9e3062724f..9adf8a14213f353d178edac6764259d74b8a2081 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" + import os from typing import Any, Optional import torch import torch.distributed import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr import vllm.envs as envs from vllm.config import VllmConfig @@ -21,19 +19,27 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") + import torch_xla.core.xla_model as xm + import torch_xla.debug.profiler as xp + import torch_xla.runtime as xr + + from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT + from vllm.v1.worker.tpu_model_runner import TPUModelRunner + class TPUWorker: @@ -325,9 +331,7 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) -try: +if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TPUWorker.") - pass diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b138f11af1eb1b7e9750598360f2f10b2e5bd339..6767804c71b9f422b8015a5096bfd49e150763da 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -10,9 +10,10 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry from vllm.v1.attention.backends.utils import AttentionMetadataBuilder -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec if TYPE_CHECKING: @@ -27,35 +28,36 @@ class MultiModalBudget: model_config: ModelConfig, scheduler_config: SchedulerConfig, mm_registry: MultiModalRegistry, - *, - max_model_len: int, - max_num_reqs: int, ) -> None: super().__init__() self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=mm_registry, + self.max_model_len = model_config.max_model_len + self.max_num_reqs = scheduler_config.max_num_seqs + + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) + + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) + + encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, ) - self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_compute_budget = encoder_compute_budget self.encoder_cache_size = encoder_cache_size - self.max_model_len = max_model_len - self.max_num_reqs = max_num_reqs - - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) max_items_per_prompt_by_modality = dict[str, int]() max_items_per_batch_by_modality = dict[str, int]() - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - for modality, max_tokens in max_tokens_by_modality.items(): ( max_items_per_prompt, @@ -69,15 +71,14 @@ class MultiModalBudget: self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality self.max_items_per_batch_by_modality = max_items_per_batch_by_modality - def get_modality_with_max_tokens(self) -> tuple[str, int]: + def get_modality_with_max_tokens(self) -> str: max_tokens_by_modality = self.max_tokens_by_modality - modality, max_tokens = max(max_tokens_by_modality.items(), - key=lambda item: item[1]) + modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) - return modality, max_tokens + return modality def get_encoder_budget(self) -> int: - return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) + return min(self.encoder_compute_budget, self.encoder_cache_size) def get_max_items( self, @@ -171,10 +172,10 @@ def scatter_mm_placeholders( Args: embeds: The multimodal embeddings. - Shape: `(num_embeds, embed_dim)` + Shape: `(num_embeds, embed_dim)` is_embed: A boolean mask indicating which positions in the placeholder - tokens need to be filled with multimodal embeddings. - Shape: `(num_placeholders, num_embeds)` + tokens need to be filled with multimodal embeddings. + Shape: `(num_placeholders, num_embeds)` """ if is_embed is None: return embeds @@ -202,12 +203,10 @@ def gather_mm_placeholders( return placeholders[is_embed] -def initialize_kv_cache_for_kv_sharing( +def add_kv_sharing_layers_to_kv_cache_groups( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], - kv_caches: dict[str, torch.Tensor], - # Optional for now to avoid breaking TPU - attn_groups: Optional[list[list[AttentionGroup]]] = None, + runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -221,38 +220,18 @@ def initialize_kv_cache_for_kv_sharing( means this layer will perform attention using the keys and values from the KV cache of `shared_kv_cache_layers[layer_name]`. kv_cache_groups: The KV cache groups of the model. - kv_caches: The allocated kv_caches with layer names as keys. - Note that layers in shared_kv_cache_layers.keys() are not - originally included as it only contains layers which have its own - KV cache allocation. - attn_groups: Optional list of attention groups. Layers in the same KV - cache group may be placed in different attention groups if they - have different attention backends. Currently only provided by - GPU model runner. """ - # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx) - layer_to_attn_group_idx: dict[str, tuple[int, int]] = {} - if attn_groups: - for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups): - for attn_group_idx, attn_group in enumerate(kv_attn_groups): - for layer_name in attn_group.layer_names: - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, - attn_group_idx) - else: - for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - # attn group idx default to 0 if not provided - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0) + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: + for layer_name in kv_cache_group.layer_names: + layer_to_kv_cache_group[layer_name] = kv_cache_group for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] - kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0] - kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name) + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) - if attn_groups: - attn_group_idx = layer_to_attn_group_idx[target_layer_name][1] - attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( - layer_name) + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) def bind_kv_cache( @@ -273,7 +252,7 @@ def bind_kv_cache( Args: kv_caches: The allocated kv_caches with layer names as keys. forward_context: The global forward context containing all Attention - layers with layer names as keys. + layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. """ # Bind kv_caches to ModelRunner diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index a79317cbb72d013e1438b1715c62adcc8757accc..84228be3bc01f0bbf1153e684da677f8a90b2de7 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -36,8 +36,8 @@ class WorkerBase(WorkerBaseV0): local_rank: Local device index rank: Global rank in distributed setup distributed_init_method: Distributed initialization method - is_driver_worker: Whether this worker handles driver - responsibilities + is_driver_worker: Whether this worker handles driver + responsibilities """ # Configuration storage super().__init__(vllm_config=vllm_config) diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 59f8d0fcf5bd9786ab0f9709e17b508ceb776335..fb892211f19dbd46205e1c504e27435e9d6ebe2f 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner): vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False @@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: torch.xpu.synchronize() + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + # replace cuda Event with xpu Event, this should work by default + torch.cuda.Event = torch.xpu.Event + yield + finally: + # if anything goes wrong, just patch it with a placeholder + torch.cuda.Event = _EventPlaceholder diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 134d8392526531ffd4c8b779b3cca375e0c56177..17288cda8eccff905426123c3850467aef631042 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -145,6 +145,7 @@ class XPUWorker(Worker): ): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) + current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( self.local_rank).total_memory diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index cb5d5664ab5c0b0557be56a49f3092ebe8c1ed3f..12fd25f4de2ad2047f432d5f8bf2d9590a3aa31e 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.platforms import _Backend from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, @@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, - ) -> Optional[List[PoolerOutput]]: + ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4a76891b9a15f1fbc6b1eba9542b2b125d9b4344..58afafc34e56e59501baaba0d3992c0482494b62 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -88,7 +88,6 @@ class ModelInputForGPU(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None inputs_embeds: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None - token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -196,7 +195,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens[0].clear() # type: ignore self.inputs_embeds = None # type: ignore self.input_positions[0].clear() # type: ignore - self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -223,7 +221,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): input_tokens: Optional[List[List[int]]] = None, inputs_embeds: Optional[torch.Tensor] = None, input_positions: Optional[List[List[int]]] = None, - token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). @@ -288,12 +285,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() - if token_types: - self.token_types = token_types - else: - for seq_id in range(len(self.seq_ids)): - self.token_types[seq_id].clear() - self.mrope_input_positions = None if seq_lens: @@ -351,7 +342,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens = input_tokens or [] self.inputs_embeds = inputs_embeds self.input_positions = input_positions or [] - self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -379,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] - self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -403,7 +392,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): f"inputs_embeds.shape=" f"{getattr(self.inputs_embeds, 'shape', None)}, " f"input_positions={self.input_positions}, " - f"token_types={self.token_types}, " f"mrope_input_positions={self.mrope_input_positions}, " f"seq_lens={self.seq_lens}, " f"orig_seq_lens={self.orig_seq_lens}, " @@ -527,8 +515,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): prompt_embeds = seq_data.get_token_embeddings( )[context_len:seq_len] - token_types = seq_group_metadata.token_type_ids - inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() @@ -537,8 +523,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): inter_data.inputs_embeds = prompt_embeds # inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) - inter_data.token_types[seq_idx].extend( - token_types if token_types else []) + inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -596,8 +581,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -612,8 +595,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -808,12 +789,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): # Combine and flatten intermediate data. input_tokens = list[int]() inputs_embeds_list = list[torch.Tensor]() - token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) - for cur_token_types in inter_data.token_types: - token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: inputs_embeds_list.append( inter_data.inputs_embeds.to( @@ -900,11 +878,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.runner.device, self.runner.pin_memory) - token_types_tensor = async_tensor_h2d(token_types, torch.long, - self.runner.device, - self.runner.pin_memory) \ - if token_types else None - if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -961,7 +934,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): input_tokens=input_tokens_tensor, inputs_embeds=inputs_embeds, input_positions=input_positions_tensor, - token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 7b8fe2f802d687f03ac4a3238506878da6f66dc8..1008b743619a4e335d56f69682a3826bc4b26fa2 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -13,10 +13,9 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.interfaces import supports_transcription -from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) +from vllm.model_executor.models.interfaces_base import is_text_generation_model from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.tasks import GenerationTask, SupportedTask if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]): return supported_tasks - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - return list(model.pooler.get_supported_tasks()) - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() if self.model_config.runner_type == "generate": tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - tasks.extend(self.get_supported_pooling_tasks()) return tuple(tasks) diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py deleted file mode 100644 index e49783ad9b244cb024a8ca5ceb5bbdc1e3233a0b..0000000000000000000000000000000000000000 --- a/vllm/worker/pooling_model_runner.py +++ /dev/null @@ -1,214 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast - -import torch - -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.model_executor.models.interfaces_base import VllmModelForPooling -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import MultiModalKwargs -from vllm.pooling_params import PoolingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, - ModelInputForGPUBuilder) - -logger = init_logger(__name__) - - -@dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): - """ - Used by the PoolingModelRunner. - """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -class PoolingModelRunner( - GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( - ModelInputForGPUWithPoolingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - ): - super().__init__(vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithPoolingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError( - "PoolingModelRunner does not support multi-step execution.") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - virtual_engine = model_input.virtual_engine - # Pooling models are (ab-)used also to integrate non text models that - # are not autoregressive (PrithviGeosaptialMAE). - # These model might not use attention and do not really have a prefill - # and decode phase. The model input is processed in one shot and both - # decode_metadata and prefill_metadata would be None for such models. - # See the PlaceholderAttentionMetadata class. - # TODO: Figure out if cuda_graph is of any use for these models and - # explore how to leverage it. - if (prefill_meta is None and decode_meta is not None - and decode_meta.use_cuda_graph): - if model_input.inputs_embeds is None: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, False)]) - else: - graph_batch_size = model_input.inputs_embeds.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, True)]) - else: - model_executable = self.model - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.cuda.Event(enable_timing=True) - model_forward_end = torch.cuda.Event(enable_timing=True) - model_forward_start.record() - - cross_enc_kwargs = {} - if model_input.token_types is not None: - cross_enc_kwargs["token_type_ids"] = model_input.token_types - - with set_forward_context(model_input.attn_metadata, self.vllm_config, - virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **cross_enc_kwargs, - **seqlen_agnostic_kwargs, - ) - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Only perform pooling in the last pipeline stage. - if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - # Only perform pooling in the driver worker. - if not self.is_driver_worker: - return [] - - return [ - self.model.pooler(hidden_states=hidden_or_intermediate_states, - pooling_metadata=model_input.pooling_metadata) - ] - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> ModelInputForGPUWithPoolingMetadata: - return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithPoolingMetadata: - assert seq_group_metadata_list is not None - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) - - return dataclasses.replace(model_input, - pooling_metadata=pooling_metadata) - - def _prepare_pooling( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> PoolingMetadata: - """Prepare PoolingMetadata for the sequence group metadata list.""" - seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - - pooling_params = seq_group_metadata.pooling_params - assert pooling_params is not None - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") - - model = cast(VllmModelForPooling, self.model) - to_update = model.pooler.get_pooling_updates(task) - to_update.apply(pooling_params) - - seq_groups.append((seq_ids, pooling_params)) - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - pooling_metadata = PoolingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - ) - - return pooling_metadata diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 68e3605af93b89fa66a69b0098a2ff6c6ccbcad6..8b5dacea1687886439c84a9b84093adc9a1c0ca5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -3,6 +3,7 @@ """A GPU worker class.""" import gc import os +from contextlib import nullcontext from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -29,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -77,13 +77,12 @@ class Worker(LocalOrDistributedWorkerBase): "eagle", "deepseek_mtp", "glm4_moe_mtp", - "mimo_mtp")) \ + "mimo_mtp", + "ernie_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_config.runner_type == "pooling": - ModelRunnerClass = PoolingModelRunner - elif self.model_config.is_encoder_decoder: + if self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( vllm_config=self.vllm_config, @@ -97,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase): # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CacheEngine] - # Initialize gpu_cache as pooling models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} @@ -205,7 +203,6 @@ class Worker(LocalOrDistributedWorkerBase): "used for one instance per process.") context = allocator.use_memory_pool(tag="weights") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.load_model() @@ -329,7 +326,6 @@ class Worker(LocalOrDistributedWorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self._init_cache_engine() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b6c2ddb551a49b52cfb8f0dcb6b1b8f8b1c87985..db7f357a182382438d39b74d6004fa59f4ee5f21 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -594,7 +594,7 @@ class WorkerWrapperBase: Arguments are passed to the worker class constructor. """ kwargs = all_kwargs[self.rpc_rank] - self.vllm_config = kwargs.get("vllm_config", None) + self.vllm_config = kwargs.get("vllm_config") assert self.vllm_config is not None, ( "vllm_config is required to initialize the worker") enable_trace_function_call_for_thread(self.vllm_config)